Skip to content

Commit 6174549

Browse files
committed
feat: add the ability to return range for key and value
1 parent d8aa539 commit 6174549

2 files changed

Lines changed: 165 additions & 18 deletions

File tree

sqlmesh/core/linter/helpers.py

Lines changed: 106 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,11 @@ def get_range_of_a_key_in_model_block(
165165
sql: str,
166166
dialect: str,
167167
key: str,
168-
) -> t.Optional[Range]:
168+
) -> t.Optional[t.Tuple[Range, Range]]:
169169
"""
170-
Get the range of a specific key in the model block of an SQL file.
170+
Get the ranges of a specific key and its value in the MODEL block of an SQL file.
171+
172+
Returns a tuple of (key_range, value_range) if found, otherwise None.
171173
"""
172174
tokens = tokenize(sql, dialect=dialect)
173175
if not tokens:
@@ -237,17 +239,110 @@ def get_range_of_a_key_in_model_block(
237239
if depth == 1 and tt is TokenType.VAR and tok.text.upper() == key.upper():
238240
# Validate key position: it should immediately follow '(' or ',' at top level
239241
prev_idx = i - 1
240-
# Skip over non-significant tokens we don't want to gate on (e.g., comments)
242+
# Skip comments
241243
while prev_idx >= 0 and tokens[prev_idx].token_type in (TokenType.COMMENT,):
242244
prev_idx -= 1
243245
prev_tt = tokens[prev_idx].token_type if prev_idx >= 0 else None
244-
if prev_tt in (TokenType.L_PAREN, TokenType.COMMA):
245-
position = TokenPositionDetails(
246-
line=tok.line,
247-
col=tok.col,
248-
start=tok.start,
249-
end=tok.end,
250-
)
251-
return position.to_range(sql.splitlines())
246+
if prev_tt not in (TokenType.L_PAREN, TokenType.COMMA):
247+
continue
248+
249+
# Key range
250+
lines = sql.splitlines()
251+
key_start = TokenPositionDetails(
252+
line=tok.line, col=tok.col, start=tok.start, end=tok.end
253+
)
254+
key_range = key_start.to_range(lines)
255+
256+
# Find value start: the next non-comment token after the key
257+
value_start_idx = i + 1
258+
while value_start_idx < rparen_idx and tokens[value_start_idx].token_type in (
259+
TokenType.COMMENT,
260+
):
261+
value_start_idx += 1
262+
if value_start_idx >= rparen_idx:
263+
return None
264+
265+
# Walk to the end of the value expression: until top-level comma or closing paren
266+
# Track internal nesting for (), [], {}
267+
nested = 0
268+
j = value_start_idx
269+
value_end_idx = value_start_idx
270+
271+
def is_open(t: TokenType) -> bool:
272+
return t in (TokenType.L_PAREN, TokenType.L_BRACE, TokenType.L_BRACKET)
273+
274+
def is_close(t: TokenType) -> bool:
275+
return t in (TokenType.R_PAREN, TokenType.R_BRACE, TokenType.R_BRACKET)
276+
277+
while j < rparen_idx:
278+
ttype = tokens[j].token_type
279+
if ttype is TokenType.COMMENT:
280+
j += 1
281+
continue
282+
if is_open(ttype):
283+
nested += 1
284+
elif is_close(ttype):
285+
nested -= 1
286+
287+
# End of value: at top-level (nested == 0) encountering a comma or the end paren
288+
if nested == 0 and (
289+
ttype is TokenType.COMMA or (ttype is TokenType.R_PAREN and depth == 1)
290+
):
291+
# For comma, don't include it in the value range
292+
# For closing paren, include it only if it's part of the value structure
293+
if ttype is TokenType.COMMA:
294+
# Don't include the comma in the value range
295+
break
296+
else:
297+
# Include the closing parenthesis in the value range
298+
value_end_idx = j
299+
break
300+
301+
value_end_idx = j
302+
j += 1
303+
304+
# Special case: if the value ends with a closing parenthesis that's part of the value
305+
# (not the MODEL block's closing parenthesis), we need to include it
306+
if value_end_idx < rparen_idx - 1:
307+
next_token = tokens[value_end_idx + 1]
308+
if next_token.token_type is TokenType.COMMA:
309+
# Value ends before the comma, which is correct
310+
pass
311+
elif next_token.token_type is TokenType.R_PAREN and depth == 1:
312+
# This is the MODEL block's closing parenthesis, don't include it
313+
pass
314+
else:
315+
# Check if we should extend the range to include more tokens
316+
# This handles cases like incomplete parsing
317+
pass
318+
319+
# Trim trailing comments from value end
320+
while (
321+
value_end_idx > value_start_idx
322+
and tokens[value_end_idx].token_type is TokenType.COMMENT
323+
):
324+
value_end_idx -= 1
325+
326+
value_start_tok = tokens[value_start_idx]
327+
value_end_tok = tokens[value_end_idx]
328+
329+
value_start_pos = TokenPositionDetails(
330+
line=value_start_tok.line,
331+
col=value_start_tok.col,
332+
start=value_start_tok.start,
333+
end=value_start_tok.end,
334+
)
335+
value_end_pos = TokenPositionDetails(
336+
line=value_end_tok.line,
337+
col=value_end_tok.col,
338+
start=value_end_tok.start,
339+
end=value_end_tok.end,
340+
)
341+
value_range = Range(
342+
start=value_start_pos.to_range(lines).start,
343+
end=value_end_pos.to_range(lines).end,
344+
)
345+
346+
return (key_range, value_range)
252347

253348
return None

tests/core/linter/test_helpers.py

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,17 @@ def test_get_range_of_a_key_in_model_block_testing_on_sushi():
5252
]
5353
assert len(sql_models) > 0
5454

55+
# Test that the function works for all keys in the model block
5556
for model in sql_models:
56-
possible_keys = ["name", "tags", "description", "columns", "owner", "cron", "dialect"]
57+
possible_keys = [
58+
"name",
59+
"tags",
60+
"description",
61+
"column_descriptions",
62+
"owner",
63+
"cron",
64+
"dialect",
65+
]
5766

5867
dialect = model.dialect
5968
assert dialect is not None
@@ -67,12 +76,55 @@ def test_get_range_of_a_key_in_model_block_testing_on_sushi():
6776
count_properties_checked = 0
6877

6978
for key in possible_keys:
70-
range = get_range_of_a_key_in_model_block(content, dialect, key)
71-
72-
# Check that the range starts with the key and ends with ;
73-
if range:
74-
read_range = read_range_from_file(path, range)
75-
assert read_range.lower() == key.lower()
79+
ranges = get_range_of_a_key_in_model_block(content, dialect, key)
80+
81+
if ranges:
82+
key_range, value_range = ranges
83+
read_key = read_range_from_file(path, key_range)
84+
assert read_key.lower() == key.lower()
85+
# Value range should be non-empty
86+
read_value = read_range_from_file(path, value_range)
87+
assert len(read_value) > 0
7688
count_properties_checked += 1
7789

7890
assert count_properties_checked > 0
91+
92+
# Test that the function works for different kind of value blocks
93+
tests = [
94+
("sushi.customers", "name", "sushi.customers"),
95+
(
96+
"sushi.customers",
97+
"tags",
98+
"(pii, fact)",
99+
),
100+
("sushi.customers", "description", "'Sushi customer data'"),
101+
(
102+
"sushi.customers",
103+
"column_descriptions",
104+
"( customer_id = 'customer_id uniquely identifies customers' )",
105+
),
106+
("sushi.customers", "owner", "jen"),
107+
("sushi.customers", "cron", "'@daily'"),
108+
]
109+
for model_name, key, value in tests:
110+
model = context.get_model(model_name)
111+
assert model is not None
112+
113+
dialect = model.dialect
114+
assert dialect is not None
115+
116+
path = model._path
117+
assert path is not None
118+
119+
with open(path, "r", encoding="utf-8") as file:
120+
content = file.read()
121+
122+
ranges = get_range_of_a_key_in_model_block(content, dialect, key)
123+
assert ranges is not None, f"Could not find key '{key}' in model '{model_name}'"
124+
125+
key_range, value_range = ranges
126+
read_key = read_range_from_file(path, key_range)
127+
assert read_key.lower() == key.lower()
128+
129+
read_value = read_range_from_file(path, value_range)
130+
assert read_value == value

0 commit comments

Comments
 (0)