From ee1b056bf6338158bd8e42f9b4a12188c6c0408c Mon Sep 17 00:00:00 2001 From: Andrzej J Skalski Date: Thu, 5 Feb 2026 17:08:36 +0100 Subject: [PATCH 1/2] fix(langserver): support go-to-definition for plugin-defined rules Previously, go-to-definition only worked for core builtin functions. Plugin-defined rules like go_library, go_repo, etc. would return no results because they were parsed by a different parser instance than the one used by the language server. Changes: - Use parse.InitParser() to initialize the parser on BuildState, then get the same parser via parse.GetAspParser() for the language server - Add periodic loading of function definitions (every 2 seconds) so go-to-definition works progressively while the full parse runs - Add Range() method to cmap types to iterate over parsed ASTs - Add AllFunctionsByFile() to asp.Parser to retrieve function definitions - Fix file URIs to use absolute paths --- src/cmap/cerrmap.go | 12 +++++ src/cmap/cmap.go | 26 +++++++++++ src/parse/asp/parser.go | 19 ++++++++ src/parse/init.go | 13 ++++++ tools/build_langserver/lsp/BUILD | 1 + tools/build_langserver/lsp/definition.go | 34 +++++++++----- tools/build_langserver/lsp/lsp.go | 58 +++++++++++++++++++++++- 7 files changed, 149 insertions(+), 14 deletions(-) diff --git a/src/cmap/cerrmap.go b/src/cmap/cerrmap.go index 687c9191bf..b29acebf1f 100644 --- a/src/cmap/cerrmap.go +++ b/src/cmap/cerrmap.go @@ -78,3 +78,15 @@ func (m *ErrMap[K, V]) GetOrSet(key K, f func() (V, error)) (V, error) { } return v.Val, v.Err } + +// Range calls f for each key-value pair in the map. +// If f returns false, iteration stops. +// No particular consistency guarantees are made during iteration. +func (m *ErrMap[K, V]) Range(f func(key K, val V) bool) { + m.m.Range(func(key K, val errV[V]) bool { + if val.Err != nil { + return true // skip errors + } + return f(key, val.Val) + }) +} diff --git a/src/cmap/cmap.go b/src/cmap/cmap.go index ce8508b454..0f0d8b0c4f 100644 --- a/src/cmap/cmap.go +++ b/src/cmap/cmap.go @@ -94,6 +94,17 @@ func (m *Map[K, V]) Values() []V { return ret } +// Range calls f for each key-value pair in the map. +// If f returns false, iteration stops. +// No particular consistency guarantees are made during iteration. +func (m *Map[K, V]) Range(f func(key K, val V) bool) { + for i := 0; i < len(m.shards); i++ { + if !m.shards[i].Range(f) { + return + } + } +} + // An awaitableValue represents a value in the map & an awaitable channel for it to exist. type awaitableValue[V any] struct { Val V @@ -195,3 +206,18 @@ func (s *shard[K, V]) Contains(key K) bool { _, ok := s.m[key] return ok } + +// Range calls f for each key-value pair in this shard. +// Returns false if iteration was stopped early. +func (s *shard[K, V]) Range(f func(key K, val V) bool) bool { + s.l.RLock() + defer s.l.RUnlock() + for k, v := range s.m { + if v.Wait == nil { // Only include completed values + if !f(k, v.Val) { + return false + } + } + } + return true +} diff --git a/src/parse/asp/parser.go b/src/parse/asp/parser.go index 36f055b966..8d104a8691 100644 --- a/src/parse/asp/parser.go +++ b/src/parse/asp/parser.go @@ -257,6 +257,25 @@ func (p *Parser) optimiseBuiltinCalls(stmts []*Statement) { } } +// AllFunctionsByFile returns all function definitions grouped by filename. +// This includes functions from builtins, plugins, and subincludes. +// It iterates over the ASTs stored by the interpreter. +func (p *Parser) AllFunctionsByFile() map[string][]*Statement { + if p.interpreter == nil || p.interpreter.asts == nil { + return nil + } + result := make(map[string][]*Statement) + p.interpreter.asts.Range(func(filename string, stmts []*Statement) bool { + for _, stmt := range stmts { + if stmt.FuncDef != nil { + result[filename] = append(result[filename], stmt) + } + } + return true + }) + return result +} + // whitelistedKwargs returns true if the given built-in function name is allowed to // be called as non-kwargs. // TODO(peterebden): Come up with a syntax that exposes this directly in the file. diff --git a/src/parse/init.go b/src/parse/init.go index 663e265104..ee67dacda5 100644 --- a/src/parse/init.go +++ b/src/parse/init.go @@ -25,6 +25,19 @@ func InitParser(state *core.BuildState) *core.BuildState { return state } +// GetAspParser returns the underlying asp.Parser from the state's parser. +// This is useful for tools like the language server that need direct access to AST information. +// Returns nil if the state's parser is not set or is not an aspParser. +func GetAspParser(state *core.BuildState) *asp.Parser { + if state.Parser == nil { + return nil + } + if ap, ok := state.Parser.(*aspParser); ok { + return ap.parser + } + return nil +} + // aspParser implements the core.Parser interface around our parser package. type aspParser struct { parser *asp.Parser diff --git a/tools/build_langserver/lsp/BUILD b/tools/build_langserver/lsp/BUILD index f8c5f3485c..d0be35f006 100644 --- a/tools/build_langserver/lsp/BUILD +++ b/tools/build_langserver/lsp/BUILD @@ -17,6 +17,7 @@ go_library( "//rules", "//src/core", "//src/fs", + "//src/parse", "//src/parse/asp", "//src/plz", "//tools/build_langserver/lsp/astutils", diff --git a/tools/build_langserver/lsp/definition.go b/tools/build_langserver/lsp/definition.go index 9ee1c5df18..50c0fdd51c 100644 --- a/tools/build_langserver/lsp/definition.go +++ b/tools/build_langserver/lsp/definition.go @@ -18,20 +18,20 @@ func (h *Handler) definition(params *lsp.TextDocumentPositionParams) ([]lsp.Loca ast := h.parseIfNeeded(doc) f := doc.AspFile() - var locs []lsp.Location + locs := []lsp.Location{} pos := aspPos(params.Position) asp.WalkAST(ast, func(expr *asp.Expression) bool { - if !asp.WithinRange(pos, f.Pos(expr.Pos), f.Pos(expr.EndPos)) { + exprStart := f.Pos(expr.Pos) + exprEnd := f.Pos(expr.EndPos) + if !asp.WithinRange(pos, exprStart, exprEnd) { return false } - if expr.Val.Ident != nil { if loc := h.findGlobal(expr.Val.Ident.Name); loc.URI != "" { locs = append(locs, loc) } return false } - if expr.Val.String != "" { label := astutils.TrimStrLit(expr.Val.String) if loc := h.findLabel(doc.PkgName, label); loc.URI != "" { @@ -39,20 +39,19 @@ func (h *Handler) definition(params *lsp.TextDocumentPositionParams) ([]lsp.Loca } return false } - return true }) - // It might also be a statement. + // It might also be a statement (e.g. a function call like go_library(...)) asp.WalkAST(ast, func(stmt *asp.Statement) bool { if stmt.Ident != nil { - endPos := f.Pos(stmt.Pos) + stmtStart := f.Pos(stmt.Pos) + endPos := stmtStart // TODO(jpoole): The AST should probably just have this information endPos.Column += len(stmt.Ident.Name) - if !asp.WithinRange(pos, f.Pos(stmt.Pos), endPos) { - return false + if !asp.WithinRange(pos, stmtStart, endPos) { + return true // continue to other statements } - if loc := h.findGlobal(stmt.Ident.Name); loc.URI != "" { locs = append(locs, loc) } @@ -78,6 +77,9 @@ func (h *Handler) findLabel(currentPath, label string) lsp.Location { } pkg := h.state.Graph.PackageByLabel(l) + if pkg == nil { + return lsp.Location{} + } uri := lsp.DocumentURI("file://" + filepath.Join(h.root, pkg.Filename)) loc := lsp.Location{URI: uri} doc, err := h.maybeOpenDoc(uri) @@ -137,9 +139,17 @@ func findName(args []asp.CallArgument) string { // findGlobal returns the location of a global of the given name. func (h *Handler) findGlobal(name string) lsp.Location { - if f, present := h.builtins[name]; present { + h.mutex.Lock() + f, present := h.builtins[name] + h.mutex.Unlock() + if present { + filename := f.Pos.Filename + // Make path absolute if it's relative + if !filepath.IsAbs(filename) { + filename = filepath.Join(h.root, filename) + } return lsp.Location{ - URI: lsp.DocumentURI("file://" + f.Pos.Filename), + URI: lsp.DocumentURI("file://" + filename), Range: rng(f.Pos, f.EndPos), } } diff --git a/tools/build_langserver/lsp/lsp.go b/tools/build_langserver/lsp/lsp.go index b979360177..0662dc529c 100644 --- a/tools/build_langserver/lsp/lsp.go +++ b/tools/build_langserver/lsp/lsp.go @@ -20,6 +20,7 @@ import ( "github.com/thought-machine/please/rules" "github.com/thought-machine/please/src/core" "github.com/thought-machine/please/src/fs" + "github.com/thought-machine/please/src/parse" "github.com/thought-machine/please/src/parse/asp" "github.com/thought-machine/please/src/plz" ) @@ -195,16 +196,38 @@ func (h *Handler) initialize(params *lsp.InitializeParams) (*lsp.InitializeResul } h.state = core.NewBuildState(config) h.state.NeedBuild = false - // We need an unwrapped parser instance as well for raw access. - h.parser = asp.NewParser(h.state) + // Initialize the parser on state first, so that plz.RunHost uses the same parser. + // This ensures plugin subincludes are stored in the same AST cache we use. + parse.InitParser(h.state) + h.parser = parse.GetAspParser(h.state) + if h.parser == nil { + return nil, fmt.Errorf("failed to get asp parser from state") + } // Parse everything in the repo up front. // This is a lot easier than trying to do clever partial parses later on, although // eventually we may want that if we start dealing with truly large repos. go func() { + // Start a goroutine to periodically load parser functions as they become available. + // This allows go-to-definition to work progressively while the full parse runs. + done := make(chan struct{}) + go func() { + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + for { + select { + case <-done: + return + case <-ticker.C: + h.loadParserFunctions() + } + } + }() plz.RunHost(core.WholeGraph, h.state) + close(done) log.Debug("initial parse complete") h.buildPackageTree() log.Debug("built completion package tree") + h.loadParserFunctions() }() // Record all the builtin functions now if err := h.loadBuiltins(); err != nil { @@ -268,6 +291,37 @@ func (h *Handler) loadBuiltins() error { return nil } +// loadParserFunctions loads function definitions from the parser's ASTs. +// This includes plugin-defined functions like go_library, python_library, etc. +func (h *Handler) loadParserFunctions() { + funcsByFile := h.parser.AllFunctionsByFile() + if funcsByFile == nil { + return + } + h.mutex.Lock() + defer h.mutex.Unlock() + for filename, stmts := range funcsByFile { + // Read the file to create a File object for position conversion + data, err := os.ReadFile(filename) + if err != nil { + log.Warning("failed to read file %s: %v", filename, err) + continue + } + file := asp.NewFile(filename, data) + for _, stmt := range stmts { + name := stmt.FuncDef.Name + // Only add if not already present (don't override core builtins) + if _, present := h.builtins[name]; !present { + h.builtins[name] = builtin{ + Stmt: stmt, + Pos: file.Pos(stmt.Pos), + EndPos: file.Pos(stmt.EndPos), + } + } + } + } +} + // fromURI converts a DocumentURI to a path. func fromURI(uri lsp.DocumentURI) string { if !strings.HasPrefix(string(uri), "file://") { From 0dd04538274e382b85dfecaad05325133c7d751c Mon Sep 17 00:00:00 2001 From: Andrzej J Skalski Date: Thu, 5 Feb 2026 17:22:00 +0100 Subject: [PATCH 2/2] feat(langserver): add find-all-references support Implements textDocument/references for the BUILD file language server. Supports two modes: 1. Function references: When cursor is on a function definition (e.g., `def go_repo(...)`), finds all BUILD files that call that function. 2. Build label references: When cursor is on a build label, uses query.FindRevdeps to find all targets that depend on it, then locates the exact string literal positions in their BUILD files. --- tools/build_langserver/lsp/BUILD | 2 + tools/build_langserver/lsp/lsp.go | 7 + tools/build_langserver/lsp/references.go | 214 +++++++++++++++++++++++ 3 files changed, 223 insertions(+) create mode 100644 tools/build_langserver/lsp/references.go diff --git a/tools/build_langserver/lsp/BUILD b/tools/build_langserver/lsp/BUILD index d0be35f006..3498e5aef0 100644 --- a/tools/build_langserver/lsp/BUILD +++ b/tools/build_langserver/lsp/BUILD @@ -5,6 +5,7 @@ go_library( "definition.go", "diagnostics.go", "lsp.go", + "references.go", "symbols.go", "text.go", ], @@ -20,6 +21,7 @@ go_library( "//src/parse", "//src/parse/asp", "//src/plz", + "//src/query", "//tools/build_langserver/lsp/astutils", ], ) diff --git a/tools/build_langserver/lsp/lsp.go b/tools/build_langserver/lsp/lsp.go index 0662dc529c..558ad68138 100644 --- a/tools/build_langserver/lsp/lsp.go +++ b/tools/build_langserver/lsp/lsp.go @@ -174,6 +174,12 @@ func (h *Handler) handle(method string, params *json.RawMessage) (res interface{ return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidParams} } return h.definition(positionParams) + case "textDocument/references": + referenceParams := &lsp.ReferenceParams{} + if err := json.Unmarshal(*params, referenceParams); err != nil { + return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidParams} + } + return h.references(referenceParams) default: return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeMethodNotFound} } @@ -244,6 +250,7 @@ func (h *Handler) initialize(params *lsp.InitializeParams) (*lsp.InitializeResul DocumentFormattingProvider: true, DocumentSymbolProvider: true, DefinitionProvider: true, + ReferencesProvider: true, CompletionProvider: &lsp.CompletionOptions{ TriggerCharacters: []string{"/", ":"}, }, diff --git a/tools/build_langserver/lsp/references.go b/tools/build_langserver/lsp/references.go new file mode 100644 index 0000000000..ca52986ea5 --- /dev/null +++ b/tools/build_langserver/lsp/references.go @@ -0,0 +1,214 @@ +package lsp + +import ( + "path/filepath" + + "github.com/sourcegraph/go-lsp" + + "github.com/thought-machine/please/src/core" + "github.com/thought-machine/please/src/parse/asp" + "github.com/thought-machine/please/src/query" + "github.com/thought-machine/please/tools/build_langserver/lsp/astutils" +) + +// references implements 'find all references' support. +func (h *Handler) references(params *lsp.ReferenceParams) ([]lsp.Location, error) { + doc := h.doc(params.TextDocument.URI) + ast := h.parseIfNeeded(doc) + f := doc.AspFile() + pos := aspPos(params.Position) + + // Check if cursor is on a function definition (def funcname(...)) + var funcName string + asp.WalkAST(ast, func(stmt *asp.Statement) bool { + if stmt.FuncDef != nil { + stmtStart := f.Pos(stmt.Pos) + // Check if cursor is on the function name + nameEnd := stmtStart + nameEnd.Column += len("def ") + len(stmt.FuncDef.Name) + if asp.WithinRange(pos, stmtStart, nameEnd) { + funcName = stmt.FuncDef.Name + return false + } + } + return true + }) + + // If we found a function definition, find all calls to it + if funcName != "" { + return h.findFunctionReferences(funcName, params.Context.IncludeDeclaration) + } + + // Otherwise, look for build label references + return h.findLabelReferences(doc, ast, f, pos, params.Context.IncludeDeclaration) +} + +// findFunctionReferences finds all calls to a function across all BUILD files. +func (h *Handler) findFunctionReferences(funcName string, includeDeclaration bool) ([]lsp.Location, error) { + locs := []lsp.Location{} + + // Search all packages for calls to this function + for _, pkg := range h.state.Graph.PackageMap() { + uri := lsp.DocumentURI("file://" + filepath.Join(h.root, pkg.Filename)) + refDoc, err := h.maybeOpenDoc(uri) + if err != nil { + continue + } + refAst := h.parseIfNeeded(refDoc) + refFile := refDoc.AspFile() + + // Find all statement calls to the function (e.g., go_library(...)) + asp.WalkAST(refAst, func(stmt *asp.Statement) bool { + if stmt.Ident != nil && stmt.Ident.Name == funcName { + start := refFile.Pos(stmt.Pos) + end := start + end.Column += len(funcName) + locs = append(locs, lsp.Location{ + URI: uri, + Range: lsp.Range{ + Start: lsp.Position{Line: start.Line - 1, Character: start.Column - 1}, + End: lsp.Position{Line: end.Line - 1, Character: end.Column - 1}, + }, + }) + } + return true + }) + + // Find expression calls (e.g., x = go_library(...)) + asp.WalkAST(refAst, func(expr *asp.Expression) bool { + if expr.Val.Ident != nil && expr.Val.Ident.Name == funcName && len(expr.Val.Ident.Action) > 0 && expr.Val.Ident.Action[0].Call != nil { + start := refFile.Pos(expr.Pos) + end := start + end.Column += len(funcName) + locs = append(locs, lsp.Location{ + URI: uri, + Range: lsp.Range{ + Start: lsp.Position{Line: start.Line - 1, Character: start.Column - 1}, + End: lsp.Position{Line: end.Line - 1, Character: end.Column - 1}, + }, + }) + } + return true + }) + } + + // Include the definition itself if requested + if includeDeclaration { + h.mutex.Lock() + if builtin, ok := h.builtins[funcName]; ok { + filename := builtin.Pos.Filename + if !filepath.IsAbs(filename) { + filename = filepath.Join(h.root, filename) + } + locs = append(locs, lsp.Location{ + URI: lsp.DocumentURI("file://" + filename), + Range: rng(builtin.Pos, builtin.EndPos), + }) + } + h.mutex.Unlock() + } + + return locs, nil +} + +// findLabelReferences finds all references to a build label. +func (h *Handler) findLabelReferences(doc *doc, ast []*asp.Statement, f *asp.File, pos asp.FilePosition, includeDeclaration bool) ([]lsp.Location, error) { + var targetLabel core.BuildLabel + var targetName string + + // Check if cursor is on a string (build label) + asp.WalkAST(ast, func(expr *asp.Expression) bool { + exprStart := f.Pos(expr.Pos) + exprEnd := f.Pos(expr.EndPos) + if !asp.WithinRange(pos, exprStart, exprEnd) { + return false + } + if expr.Val.String != "" { + label := astutils.TrimStrLit(expr.Val.String) + if l, err := core.TryParseBuildLabel(label, doc.PkgName, ""); err == nil { + targetLabel = l + } + return false + } + return true + }) + + // Check if cursor is on a target definition (name = "...") + if targetLabel.IsEmpty() { + asp.WalkAST(ast, func(stmt *asp.Statement) bool { + if stmt.Ident != nil && stmt.Ident.Action != nil && stmt.Ident.Action.Call != nil { + stmtStart := f.Pos(stmt.Pos) + stmtEnd := f.Pos(stmt.EndPos) + if asp.WithinRange(pos, stmtStart, stmtEnd) { + if name := findName(stmt.Ident.Action.Call.Arguments); name != "" { + targetLabel = core.BuildLabel{PackageName: doc.PkgName, Name: name} + targetName = name + } + } + return false + } + return true + }) + } + + if targetLabel.IsEmpty() { + return []lsp.Location{}, nil + } + + // Use query.FindRevdeps to find all reverse dependencies + // Parameters: hidden=false, followSubincludes=true, includeSubrepos=true, depth=-1 (unlimited) + revdeps := query.FindRevdeps(h.state, core.BuildLabels{targetLabel}, false, true, true, -1) + + locs := []lsp.Location{} + + // For each reverse dependency, find the exact location of the reference in its BUILD file + for target := range revdeps { + pkg := h.state.Graph.PackageByLabel(target.Label) + if pkg == nil { + continue + } + + uri := lsp.DocumentURI("file://" + filepath.Join(h.root, pkg.Filename)) + refDoc, err := h.maybeOpenDoc(uri) + if err != nil { + continue + } + refAst := h.parseIfNeeded(refDoc) + refFile := refDoc.AspFile() + + // Find all string literals that reference our target + labelStr := targetLabel.String() + shortLabelStr := ":" + targetLabel.Name // For same-package references + + asp.WalkAST(refAst, func(expr *asp.Expression) bool { + if expr.Val.String != "" { + str := astutils.TrimStrLit(expr.Val.String) + // Check if this string matches our target label + if str == labelStr || (refDoc.PkgName == targetLabel.PackageName && str == shortLabelStr) { + // Also try parsing it as a label to handle relative references + if l, err := core.TryParseBuildLabel(str, refDoc.PkgName, ""); err == nil && l == targetLabel { + start := refFile.Pos(expr.Pos) + end := refFile.Pos(expr.EndPos) + locs = append(locs, lsp.Location{ + URI: uri, + Range: lsp.Range{ + Start: lsp.Position{Line: start.Line - 1, Character: start.Column - 1}, + End: lsp.Position{Line: end.Line - 1, Character: end.Column - 1}, + }, + }) + } + } + } + return true + }) + } + + // Optionally include the definition itself if requested + if includeDeclaration && targetName != "" { + if defLoc := h.findLabel(doc.PkgName, targetLabel.String()); defLoc.URI != "" { + locs = append(locs, defLoc) + } + } + + return locs, nil +}