Skip to content

Commit 0eab0a6

Browse files
authored
fix!: properly nested update gateway configs (#3085)
1 parent 2664a6d commit 0eab0a6

3 files changed

Lines changed: 89 additions & 9 deletions

File tree

sqlmesh/core/config/base.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,15 @@ def update_field(
3636
3737
The updated field
3838
"""
39+
40+
def _update_pydantic_config(old: BaseConfig, new: BaseConfig) -> PydanticModel:
41+
if type(new) != type(old):
42+
raise ConfigError(
43+
"NESTED_UPDATE behavior requires both values to have the same type. "
44+
f"{type(old)} and {type(new)} were given instead."
45+
)
46+
return old.update_with(new)
47+
3948
if not old:
4049
return new
4150

@@ -78,18 +87,20 @@ def update_field(
7887

7988
return combined
8089
if update_strategy == UpdateStrategy.NESTED_UPDATE:
81-
if not isinstance(old, BaseConfig):
90+
if not isinstance(old, BaseConfig) and not isinstance(old, dict):
8291
raise ConfigError(
83-
f"NESTED_UPDATE behavior requires a config object. {type(old)} was given instead."
92+
f"NESTED_UPDATE behavior requires a config object and a dict of config objects as values. {type(old)} was given instead."
8493
)
8594

86-
if type(new) != type(old):
87-
raise ConfigError(
88-
"NESTED_UPDATE behavior requires both values to have the same type. "
89-
f"{type(old)} and {type(new)} were given instead."
90-
)
95+
if isinstance(old, dict):
96+
for k, pydantic_model in new.items():
97+
if k in old:
98+
old[k] = _update_pydantic_config(old[k], pydantic_model)
99+
else:
100+
old[k] = pydantic_model
91101

92-
return old.update_with(new)
102+
return old
103+
return _update_pydantic_config(old, new)
93104

94105
raise ConfigError(f"Unknown update strategy {update_strategy}.")
95106

sqlmesh/core/config/root.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ class Config(BaseConfig):
118118
disable_anonymized_analytics: bool = False
119119

120120
_FIELD_UPDATE_STRATEGY: t.ClassVar[t.Dict[str, UpdateStrategy]] = {
121-
"gateways": UpdateStrategy.KEY_UPDATE,
121+
"gateways": UpdateStrategy.NESTED_UPDATE,
122122
"notification_targets": UpdateStrategy.EXTEND,
123123
"ignore_patterns": UpdateStrategy.EXTEND,
124124
"users": UpdateStrategy.EXTEND,

tests/core/test_config.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
DuckDBConnectionConfig,
1313
GatewayConfig,
1414
ModelDefaultsConfig,
15+
BigQueryConnectionConfig,
16+
MotherDuckConnectionConfig,
1517
)
1618
from sqlmesh.core.config.connection import DuckDBAttachOptions
1719
from sqlmesh.core.config.feature_flag import DbtFeatureFlag, FeatureFlag
@@ -307,6 +309,73 @@ def test_load_config_from_env_invalid_variable_name():
307309
load_config_from_env()
308310

309311

312+
def test_load_yaml_config_env_var_gateway_override(tmp_path_factory):
313+
config_path = tmp_path_factory.mktemp("yaml_config") / "config.yaml"
314+
with open(config_path, "w", encoding="utf-8") as fd:
315+
fd.write(
316+
"""
317+
gateways:
318+
testing:
319+
connection:
320+
type: motherduck
321+
database: blah
322+
model_defaults:
323+
dialect: bigquery
324+
"""
325+
)
326+
with mock.patch.dict(
327+
os.environ,
328+
{
329+
"SQLMESH__GATEWAYS__TESTING__STATE_CONNECTION__TYPE": "bigquery",
330+
"SQLMESH__DEFAULT_GATEWAY": "testing",
331+
},
332+
):
333+
assert load_config_from_paths(
334+
Config,
335+
project_paths=[config_path],
336+
) == Config(
337+
gateways={
338+
"testing": GatewayConfig(
339+
connection=MotherDuckConnectionConfig(database="blah"),
340+
state_connection=BigQueryConnectionConfig(),
341+
),
342+
},
343+
model_defaults=ModelDefaultsConfig(dialect="bigquery"),
344+
default_gateway="testing",
345+
)
346+
347+
348+
def test_load_py_config_env_var_gateway_override(tmp_path_factory):
349+
config_path = tmp_path_factory.mktemp("python_config") / "config.py"
350+
with open(config_path, "w", encoding="utf-8") as fd:
351+
fd.write(
352+
"""from sqlmesh.core.config import Config, DuckDBConnectionConfig, GatewayConfig, ModelDefaultsConfig
353+
config = Config(gateways={"duckdb_gateway": GatewayConfig(connection=DuckDBConnectionConfig())}, model_defaults=ModelDefaultsConfig(dialect=''))
354+
"""
355+
)
356+
with mock.patch.dict(
357+
os.environ,
358+
{
359+
"SQLMESH__GATEWAYS__DUCKDB_GATEWAY__STATE_CONNECTION__TYPE": "bigquery",
360+
"SQLMESH__DEFAULT_GATEWAY": "duckdb_gateway",
361+
},
362+
):
363+
config = load_config_from_paths(
364+
Config,
365+
project_paths=[config_path],
366+
)
367+
assert config == Config(
368+
gateways={ # type: ignore
369+
"duckdb_gateway": GatewayConfig(
370+
connection=DuckDBConnectionConfig(),
371+
state_connection=BigQueryConnectionConfig(),
372+
),
373+
},
374+
model_defaults=ModelDefaultsConfig(dialect=""),
375+
default_gateway="duckdb_gateway",
376+
)
377+
378+
310379
def test_load_config_from_python_module_missing_config(tmp_path):
311380
config_path = tmp_path / "missing_config.py"
312381
with open(config_path, "w", encoding="utf-8") as fd:

0 commit comments

Comments
 (0)