Skip to content

Commit 1da137a

Browse files
authored
refactor(lsp): refactor lsp calls into context (#4616)
1 parent b5989c8 commit 1da137a

File tree

2 files changed

+106
-81
lines changed

2 files changed

+106
-81
lines changed

sqlmesh/lsp/context.py

Lines changed: 92 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import typing as t
55

66
from sqlmesh.core.model.definition import SqlModel
7+
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
78
from sqlmesh.lsp.custom import RenderModelEntry
89
from sqlmesh.lsp.uri import URI
910

@@ -28,8 +29,14 @@ class LSPContext:
2829
model names and standalone audit names.
2930
"""
3031

32+
map: t.Dict[Path, t.Union[ModelTarget, AuditTarget]]
33+
_render_cache: t.Dict[Path, t.List[RenderModelEntry]]
34+
_lint_cache: t.Dict[Path, t.List[AnnotatedRuleViolation]]
35+
3136
def __init__(self, context: Context) -> None:
3237
self.context = context
38+
self._render_cache = {}
39+
self._lint_cache = {}
3340

3441
# Add models to the map
3542
model_map: t.Dict[Path, ModelTarget] = {}
@@ -54,36 +61,90 @@ def __init__(self, context: Context) -> None:
5461
**audit_map,
5562
}
5663

64+
def render_model(self, uri: URI) -> t.List[RenderModelEntry]:
65+
"""Get rendered models for a file, using cache when available.
5766
58-
def render_model(context: LSPContext, uri: URI) -> t.Iterator[RenderModelEntry]:
59-
target = context.map[uri.to_path()]
60-
if isinstance(target, AuditTarget):
61-
audit = context.context.standalone_audits[target.name]
62-
definition = audit.render_definition(
63-
include_python=False,
64-
render_query=True,
65-
)
66-
rendered_query = [render.sql(dialect=audit.dialect, pretty=True) for render in definition]
67-
yield RenderModelEntry(
68-
name=audit.name,
69-
fqn=audit.fqn,
70-
description=audit.description,
71-
rendered_query="\n\n".join(rendered_query),
72-
)
73-
if isinstance(target, ModelTarget):
74-
for name in target.names:
75-
model = context.context.get_model(name)
76-
if isinstance(model, SqlModel):
77-
rendered_query = [
78-
render.sql(dialect=model.dialect, pretty=True)
79-
for render in model.render_definition(
80-
include_python=False,
81-
render_query=True,
67+
Args:
68+
uri: The URI of the file to render.
69+
70+
Returns:
71+
List of rendered model entries.
72+
"""
73+
path = uri.to_path()
74+
75+
# Check cache first
76+
if path in self._render_cache:
77+
return self._render_cache[path]
78+
79+
# If not cached, render and cache
80+
entries: t.List[RenderModelEntry] = []
81+
target = self.map.get(path)
82+
83+
if isinstance(target, AuditTarget):
84+
audit = self.context.standalone_audits[target.name]
85+
definition = audit.render_definition(
86+
include_python=False,
87+
render_query=True,
88+
)
89+
rendered_query = [
90+
render.sql(dialect=audit.dialect, pretty=True) for render in definition
91+
]
92+
entry = RenderModelEntry(
93+
name=audit.name,
94+
fqn=audit.fqn,
95+
description=audit.description,
96+
rendered_query="\n\n".join(rendered_query),
97+
)
98+
entries.append(entry)
99+
100+
elif isinstance(target, ModelTarget):
101+
for name in target.names:
102+
model = self.context.get_model(name)
103+
if isinstance(model, SqlModel):
104+
rendered_query = [
105+
render.sql(dialect=model.dialect, pretty=True)
106+
for render in model.render_definition(
107+
include_python=False,
108+
render_query=True,
109+
)
110+
]
111+
entry = RenderModelEntry(
112+
name=model.name,
113+
fqn=model.fqn,
114+
description=model.description,
115+
rendered_query="\n\n".join(rendered_query),
82116
)
83-
]
84-
yield RenderModelEntry(
85-
name=model.name,
86-
fqn=model.fqn,
87-
description=model.description,
88-
rendered_query="\n\n".join(rendered_query),
89-
)
117+
entries.append(entry)
118+
119+
# Store in cache
120+
self._render_cache[path] = entries
121+
return entries
122+
123+
def lint_model(self, uri: URI) -> t.List[AnnotatedRuleViolation]:
124+
"""Get lint diagnostics for a model, using cache when available.
125+
126+
Args:
127+
uri: The URI of the file to lint.
128+
129+
Returns:
130+
List of annotated rule violations.
131+
"""
132+
path = uri.to_path()
133+
134+
# Check cache first
135+
if path in self._lint_cache:
136+
return self._lint_cache[path]
137+
138+
# If not cached, lint and cache
139+
target = self.map.get(path)
140+
if target is None or not isinstance(target, ModelTarget):
141+
return []
142+
143+
diagnostics = self.context.lint_models(
144+
target.names,
145+
raise_on_error=False,
146+
)
147+
148+
# Store in cache
149+
self._lint_cache[path] = diagnostics
150+
return diagnostics

sqlmesh/lsp/main.py

Lines changed: 14 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from sqlmesh.lsp.context import (
2525
LSPContext,
2626
ModelTarget,
27-
render_model as render_model_context,
2827
)
2928
from sqlmesh.lsp.custom import (
3029
ALL_MODELS_FEATURE,
@@ -58,10 +57,6 @@ def __init__(
5857
self.context_class = context_class
5958
self.lsp_context: t.Optional[LSPContext] = None
6059

61-
# Cache stores tuples of (diagnostics, diagnostic_version)
62-
self.lint_cache: t.Dict[URI, t.Tuple[t.List[AnnotatedRuleViolation], int]] = {}
63-
self._diagnostic_version_counter: int = 0
64-
6560
self.client_supports_pull_diagnostics = False
6661
# Register LSP features (e.g., formatting, hover, etc.)
6762
self._register_features()
@@ -120,7 +115,7 @@ def all_models(ls: LanguageServer, params: AllModelsRequest) -> AllModelsRespons
120115
def render_model(ls: LanguageServer, params: RenderModelRequest) -> RenderModelResponse:
121116
uri = URI(params.textDocumentUri)
122117
context = self._context_get_or_load(uri)
123-
return RenderModelResponse(models=list(render_model_context(context, uri)))
118+
return RenderModelResponse(models=context.render_model(uri))
124119

125120
@self.server.feature(API_FEATURE)
126121
def api(ls: LanguageServer, request: ApiRequest) -> t.Dict[str, t.Any]:
@@ -173,17 +168,11 @@ def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> Non
173168
if models is None or not isinstance(models, ModelTarget):
174169
return
175170

176-
if self.lint_cache.get(uri) is None:
177-
diagnostics = context.context.lint_models(
178-
models.names,
179-
raise_on_error=False,
180-
)
181-
self._diagnostic_version_counter += 1
182-
self.lint_cache[uri] = (diagnostics, self._diagnostic_version_counter)
171+
# Get diagnostics from context (which handles caching)
172+
diagnostics = context.lint_model(uri)
183173

184174
# Only publish diagnostics if client doesn't support pull diagnostics
185175
if not self.client_supports_pull_diagnostics:
186-
diagnostics, _ = self.lint_cache[uri]
187176
ls.publish_diagnostics(
188177
params.text_document.uri,
189178
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(diagnostics),
@@ -197,13 +186,8 @@ def did_change(ls: LanguageServer, params: types.DidChangeTextDocumentParams) ->
197186
if models is None or not isinstance(models, ModelTarget):
198187
return
199188

200-
# Always update the cache
201-
diagnostics = context.context.lint_models(
202-
models.names,
203-
raise_on_error=False,
204-
)
205-
self._diagnostic_version_counter += 1
206-
self.lint_cache[uri] = (diagnostics, self._diagnostic_version_counter)
189+
# Get diagnostics from context (which handles caching)
190+
diagnostics = context.lint_model(uri)
207191

208192
# Only publish diagnostics if client doesn't support pull diagnostics
209193
if not self.client_supports_pull_diagnostics:
@@ -220,13 +204,8 @@ def did_save(ls: LanguageServer, params: types.DidSaveTextDocumentParams) -> Non
220204
if models is None or not isinstance(models, ModelTarget):
221205
return
222206

223-
# Always update the cache
224-
diagnostics = context.context.lint_models(
225-
models.names,
226-
raise_on_error=False,
227-
)
228-
self._diagnostic_version_counter += 1
229-
self.lint_cache[uri] = (diagnostics, self._diagnostic_version_counter)
207+
# Get diagnostics from context (which handles caching)
208+
diagnostics = context.lint_model(uri)
230209

231210
# Only publish diagnostics if client doesn't support pull diagnostics
232211
if not self.client_supports_pull_diagnostics:
@@ -445,31 +424,16 @@ def workspace_diagnostic(
445424
return types.WorkspaceDiagnosticReport(items=[])
446425

447426
def _get_diagnostics_for_uri(self, uri: URI) -> t.Tuple[t.List[types.Diagnostic], int]:
448-
"""Get diagnostics for a specific URI, returning (diagnostics, result_id)."""
449-
# Check if we have cached diagnostics
450-
if uri in self.lint_cache:
451-
diagnostics, result_id = self.lint_cache[uri]
452-
return SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(diagnostics), result_id
427+
"""Get diagnostics for a specific URI, returning (diagnostics, result_id).
453428
454-
# Try to get diagnostics by loading context and linting
429+
Since we no longer track version numbers, we always return 0 as the result_id.
430+
This means pull diagnostics will always fetch fresh results.
431+
"""
455432
try:
456433
context = self._context_get_or_load(uri)
457-
models = context.map[uri.to_path()]
458-
if models is None or not isinstance(models, ModelTarget):
459-
return [], 0
460-
461-
# Lint the models and cache the results
462-
diagnostics = context.context.lint_models(
463-
models.names,
464-
raise_on_error=False,
465-
)
466-
self._diagnostic_version_counter += 1
467-
self.lint_cache[uri] = (diagnostics, self._diagnostic_version_counter)
468-
return SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(
469-
diagnostics
470-
), self._diagnostic_version_counter
434+
diagnostics = context.lint_model(uri)
435+
return SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(diagnostics), 0
471436
except Exception:
472-
# If we can't get diagnostics, return empty list with no result ID
473437
return [], 0
474438

475439
def _context_get_or_load(self, document_uri: URI) -> LSPContext:
@@ -523,7 +487,7 @@ def _ensure_context_for_document(
523487
created_context = self.context_class(paths=[path])
524488
self.lsp_context = LSPContext(created_context)
525489
loaded = True
526-
# Re-check context for document now that it's loaded
490+
# Re-check context for the document now that it's loaded
527491
return self._ensure_context_for_document(document_uri)
528492
except Exception as e:
529493
self.server.show_message(

0 commit comments

Comments
 (0)