Skip to content

Commit 7b1b78a

Browse files
Fix: Adapt get_dbt_version tuple to include patch in version (#4619)
1 parent e396c8c commit 7b1b78a

File tree

6 files changed

+28
-24
lines changed

6 files changed

+28
-24
lines changed

sqlmesh/dbt/builtin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
class Exceptions:
3232
def raise_compiler_error(self, msg: str) -> None:
33-
if DBT_VERSION >= (1, 4):
33+
if DBT_VERSION >= (1, 4, 0):
3434
from dbt.exceptions import CompilationError
3535

3636
raise CompilationError(msg)

sqlmesh/dbt/manifest.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def _load_tests(self) -> None:
238238
dependencies.macros.append(MacroReference(package="dbt", name="get_where_subquery"))
239239
dependencies.macros.append(MacroReference(package="dbt", name="should_store_failures"))
240240

241-
sql = node.raw_code if DBT_VERSION >= (1, 3) else node.raw_sql # type: ignore
241+
sql = node.raw_code if DBT_VERSION >= (1, 3, 0) else node.raw_sql # type: ignore
242242
dependencies = dependencies.union(self._extra_dependencies(sql, node.package_name))
243243
dependencies = dependencies.union(
244244
self._flatten_dependencies_from_macros(dependencies.macros, node.package_name)
@@ -278,7 +278,7 @@ def _load_models_and_seeds(self) -> None:
278278
node_name = f"{node_name}_v{node_version}"
279279

280280
if node.resource_type in {"model", "snapshot"}:
281-
sql = node.raw_code if DBT_VERSION >= (1, 3) else node.raw_sql # type: ignore
281+
sql = node.raw_code if DBT_VERSION >= (1, 3, 0) else node.raw_sql # type: ignore
282282
dependencies = Dependencies(
283283
macros=macro_references, refs=_refs(node), sources=_sources(node)
284284
)
@@ -305,7 +305,7 @@ def _load_on_run_start_end(self) -> None:
305305
if node.resource_type == "operation" and (
306306
set(node.tags) & {"on-run-start", "on-run-end"}
307307
):
308-
sql = node.raw_code if DBT_VERSION >= (1, 3) else node.raw_sql # type: ignore
308+
sql = node.raw_code if DBT_VERSION >= (1, 3, 0) else node.raw_sql # type: ignore
309309
node_name = node.name
310310
node_path = Path(node.original_file_path)
311311

@@ -339,7 +339,7 @@ def _load_manifest(self) -> Manifest:
339339

340340
variables = (
341341
self.variable_overrides
342-
if DBT_VERSION >= (1, 5)
342+
if DBT_VERSION >= (1, 5, 0)
343343
else json.dumps(self.variable_overrides)
344344
)
345345

@@ -354,7 +354,7 @@ def _load_manifest(self) -> Manifest:
354354
)
355355
flags.set_from_args(args, None)
356356

357-
if DBT_VERSION >= (1, 8):
357+
if DBT_VERSION >= (1, 8, 0):
358358
from dbt_common.context import set_invocation_context # type: ignore
359359

360360
set_invocation_context(os.environ)
@@ -371,7 +371,7 @@ def _load_manifest(self) -> Manifest:
371371

372372
self._project_name = project.project_name
373373

374-
if DBT_VERSION >= (1, 8):
374+
if DBT_VERSION >= (1, 8, 0):
375375
from dbt.mp_context import get_mp_context # type: ignore
376376

377377
register_adapter(runtime_config, get_mp_context()) # type: ignore
@@ -546,7 +546,7 @@ def _macro_references(
546546

547547

548548
def _refs(node: ManifestNode) -> t.Set[str]:
549-
if DBT_VERSION >= (1, 5):
549+
if DBT_VERSION >= (1, 5, 0):
550550
result = set()
551551
for r in node.refs:
552552
ref_name = f"{r.package}.{r.name}" if r.package else r.name

sqlmesh/dbt/relation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from sqlmesh.dbt.util import DBT_VERSION
22

33

4-
if DBT_VERSION < (1, 8):
4+
if DBT_VERSION < (1, 8, 0):
55
from dbt.contracts.relation import * # type: ignore # noqa: F403
66
else:
77
from dbt.adapters.contracts.relation import * # type: ignore # noqa: F403

sqlmesh/dbt/target.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def validate_authentication(cls, data: t.Any) -> t.Any:
167167
if not isinstance(data, dict):
168168
return data
169169

170-
if "database" not in data and DBT_VERSION >= (1, 5):
170+
if "database" not in data and DBT_VERSION >= (1, 5, 0):
171171
path = data.get("path")
172172
data["database"] = (
173173
"memory"
@@ -424,7 +424,7 @@ def relation_class(cls) -> t.Type[BaseRelation]:
424424

425425
@classproperty
426426
def column_class(cls) -> t.Type[Column]:
427-
if DBT_VERSION < (1, 6):
427+
if DBT_VERSION < (1, 6, 0):
428428
from dbt.adapters.redshift import RedshiftColumn # type: ignore
429429

430430
return RedshiftColumn

sqlmesh/dbt/util.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,18 @@
55
from dbt.version import get_installed_version
66

77

8-
def _get_dbt_version() -> t.Tuple[int, int]:
8+
def _get_dbt_version() -> t.Tuple[int, int, int]:
99
dbt_version = get_installed_version()
10-
return (int(dbt_version.major or "0"), int(dbt_version.minor or "0"))
10+
return (
11+
int(dbt_version.major or "0"),
12+
int(dbt_version.minor or "0"),
13+
int(dbt_version.patch or "0"),
14+
)
1115

1216

1317
DBT_VERSION = _get_dbt_version()
1418

15-
if DBT_VERSION < (1, 8):
19+
if DBT_VERSION < (1, 8, 0):
1620
from dbt.clients.agate_helper import table_from_data_flat, empty_table, as_matrix # type: ignore # noqa: F401
1721
else:
1822
from dbt_common.clients.agate_helper import table_from_data_flat, empty_table, as_matrix # type: ignore # noqa: F401

tests/utils/test_metaprogramming.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def other_func(a: int) -> int:
108108

109109

110110
@contextmanager
111-
def test_context_manager():
111+
def sample_context_manager():
112112
yield
113113

114114

@@ -141,7 +141,7 @@ def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2)
141141
def closure(z: int) -> int:
142142
return z + Z
143143

144-
with test_context_manager():
144+
with sample_context_manager():
145145
pass
146146

147147
return closure(y) + other_func(Y)
@@ -171,7 +171,7 @@ def test_func_globals() -> None:
171171
"exp": exp,
172172
"expressions": exp,
173173
"fetch_data": fetch_data,
174-
"test_context_manager": test_context_manager,
174+
"sample_context_manager": sample_context_manager,
175175
"function_with_custom_decorator": function_with_custom_decorator,
176176
"SQLGLOT_META": SQLGLOT_META,
177177
}
@@ -211,7 +211,7 @@ def test_normalize_source() -> None:
211211
212212
def closure(z: int):
213213
return z + Z
214-
with test_context_manager():
214+
with sample_context_manager():
215215
pass
216216
return closure(y) + other_func(Y)"""
217217
)
@@ -261,7 +261,7 @@ def test_serialize_env() -> None:
261261
262262
def closure(z: int):
263263
return z + Z
264-
with test_context_manager():
264+
with sample_context_manager():
265265
pass
266266
return closure(y) + other_func(Y)""",
267267
),
@@ -318,9 +318,9 @@ def baz(self):
318318
),
319319
"func": Executable(
320320
payload="""@contextmanager
321-
def test_context_manager():
321+
def sample_context_manager():
322322
yield""",
323-
name="test_context_manager",
323+
name="sample_context_manager",
324324
path="test_metaprogramming.py",
325325
alias="func",
326326
),
@@ -344,11 +344,11 @@ def test_context_manager():
344344
my_lambda()
345345
return X + a + W""",
346346
),
347-
"test_context_manager": Executable(
347+
"sample_context_manager": Executable(
348348
payload="""@contextmanager
349-
def test_context_manager():
349+
def sample_context_manager():
350350
yield""",
351-
name="test_context_manager",
351+
name="sample_context_manager",
352352
path="test_metaprogramming.py",
353353
),
354354
"wraps": Executable(payload="from functools import wraps", kind=ExecutableKind.IMPORT),

0 commit comments

Comments
 (0)