Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/cmap/cerrmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,15 @@ func (m *ErrMap[K, V]) GetOrSet(key K, f func() (V, error)) (V, error) {
}
return v.Val, v.Err
}

// Range calls f for each key-value pair in the map.
// If f returns false, iteration stops.
// No particular consistency guarantees are made during iteration.
func (m *ErrMap[K, V]) Range(f func(key K, val V) bool) {
m.m.Range(func(key K, val errV[V]) bool {
if val.Err != nil {
return true // skip errors
}
return f(key, val.Val)
})
}
26 changes: 26 additions & 0 deletions src/cmap/cmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,17 @@ func (m *Map[K, V]) Values() []V {
return ret
}

// Range calls f for each key-value pair in the map.
// If f returns false, iteration stops.
// No particular consistency guarantees are made during iteration.
func (m *Map[K, V]) Range(f func(key K, val V) bool) {
for i := 0; i < len(m.shards); i++ {
if !m.shards[i].Range(f) {
return
}
}
}

// An awaitableValue represents a value in the map & an awaitable channel for it to exist.
type awaitableValue[V any] struct {
Val V
Expand Down Expand Up @@ -195,3 +206,18 @@ func (s *shard[K, V]) Contains(key K) bool {
_, ok := s.m[key]
return ok
}

// Range calls f for each key-value pair in this shard.
// Returns false if iteration was stopped early.
func (s *shard[K, V]) Range(f func(key K, val V) bool) bool {
s.l.RLock()
defer s.l.RUnlock()
for k, v := range s.m {
if v.Wait == nil { // Only include completed values
if !f(k, v.Val) {
return false
}
}
}
return true
}
19 changes: 19 additions & 0 deletions src/parse/asp/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,25 @@ func (p *Parser) optimiseBuiltinCalls(stmts []*Statement) {
}
}

// AllFunctionsByFile returns all function definitions grouped by filename.
// This includes functions from builtins, plugins, and subincludes.
// It iterates over the ASTs stored by the interpreter.
func (p *Parser) AllFunctionsByFile() map[string][]*Statement {
if p.interpreter == nil || p.interpreter.asts == nil {
return nil
}
result := make(map[string][]*Statement)
p.interpreter.asts.Range(func(filename string, stmts []*Statement) bool {
for _, stmt := range stmts {
if stmt.FuncDef != nil {
result[filename] = append(result[filename], stmt)
}
}
return true
})
return result
}

// whitelistedKwargs returns true if the given built-in function name is allowed to
// be called as non-kwargs.
// TODO(peterebden): Come up with a syntax that exposes this directly in the file.
Expand Down
13 changes: 13 additions & 0 deletions src/parse/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,19 @@ func InitParser(state *core.BuildState) *core.BuildState {
return state
}

// GetAspParser returns the underlying asp.Parser from the state's parser.
// This is useful for tools like the language server that need direct access to AST information.
// Returns nil if the state's parser is not set or is not an aspParser.
func GetAspParser(state *core.BuildState) *asp.Parser {
if state.Parser == nil {
return nil
}
if ap, ok := state.Parser.(*aspParser); ok {
return ap.parser
}
return nil
}

// aspParser implements the core.Parser interface around our parser package.
type aspParser struct {
parser *asp.Parser
Expand Down
3 changes: 3 additions & 0 deletions tools/build_langserver/lsp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go_library(
"definition.go",
"diagnostics.go",
"lsp.go",
"references.go",
"symbols.go",
"text.go",
],
Expand All @@ -17,8 +18,10 @@ go_library(
"//rules",
"//src/core",
"//src/fs",
"//src/parse",
"//src/parse/asp",
"//src/plz",
"//src/query",
"//tools/build_langserver/lsp/astutils",
],
)
Expand Down
34 changes: 22 additions & 12 deletions tools/build_langserver/lsp/definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,41 +18,40 @@ func (h *Handler) definition(params *lsp.TextDocumentPositionParams) ([]lsp.Loca
ast := h.parseIfNeeded(doc)
f := doc.AspFile()

var locs []lsp.Location
locs := []lsp.Location{}
pos := aspPos(params.Position)
asp.WalkAST(ast, func(expr *asp.Expression) bool {
if !asp.WithinRange(pos, f.Pos(expr.Pos), f.Pos(expr.EndPos)) {
exprStart := f.Pos(expr.Pos)
exprEnd := f.Pos(expr.EndPos)
if !asp.WithinRange(pos, exprStart, exprEnd) {
return false
}

if expr.Val.Ident != nil {
if loc := h.findGlobal(expr.Val.Ident.Name); loc.URI != "" {
locs = append(locs, loc)
}
return false
}

if expr.Val.String != "" {
label := astutils.TrimStrLit(expr.Val.String)
if loc := h.findLabel(doc.PkgName, label); loc.URI != "" {
locs = append(locs, loc)
}
return false
}

return true
})
// It might also be a statement.
// It might also be a statement (e.g. a function call like go_library(...))
asp.WalkAST(ast, func(stmt *asp.Statement) bool {
if stmt.Ident != nil {
endPos := f.Pos(stmt.Pos)
stmtStart := f.Pos(stmt.Pos)
endPos := stmtStart
// TODO(jpoole): The AST should probably just have this information
endPos.Column += len(stmt.Ident.Name)

if !asp.WithinRange(pos, f.Pos(stmt.Pos), endPos) {
return false
if !asp.WithinRange(pos, stmtStart, endPos) {
return true // continue to other statements
}

if loc := h.findGlobal(stmt.Ident.Name); loc.URI != "" {
locs = append(locs, loc)
}
Expand All @@ -78,6 +77,9 @@ func (h *Handler) findLabel(currentPath, label string) lsp.Location {
}

pkg := h.state.Graph.PackageByLabel(l)
if pkg == nil {
return lsp.Location{}
}
uri := lsp.DocumentURI("file://" + filepath.Join(h.root, pkg.Filename))
loc := lsp.Location{URI: uri}
doc, err := h.maybeOpenDoc(uri)
Expand Down Expand Up @@ -137,9 +139,17 @@ func findName(args []asp.CallArgument) string {

// findGlobal returns the location of a global of the given name.
func (h *Handler) findGlobal(name string) lsp.Location {
if f, present := h.builtins[name]; present {
h.mutex.Lock()
f, present := h.builtins[name]
h.mutex.Unlock()
if present {
filename := f.Pos.Filename
// Make path absolute if it's relative
if !filepath.IsAbs(filename) {
filename = filepath.Join(h.root, filename)
}
return lsp.Location{
URI: lsp.DocumentURI("file://" + f.Pos.Filename),
URI: lsp.DocumentURI("file://" + filename),
Range: rng(f.Pos, f.EndPos),
}
}
Expand Down
65 changes: 63 additions & 2 deletions tools/build_langserver/lsp/lsp.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/thought-machine/please/rules"
"github.com/thought-machine/please/src/core"
"github.com/thought-machine/please/src/fs"
"github.com/thought-machine/please/src/parse"
"github.com/thought-machine/please/src/parse/asp"
"github.com/thought-machine/please/src/plz"
)
Expand Down Expand Up @@ -173,6 +174,12 @@ func (h *Handler) handle(method string, params *json.RawMessage) (res interface{
return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidParams}
}
return h.definition(positionParams)
case "textDocument/references":
referenceParams := &lsp.ReferenceParams{}
if err := json.Unmarshal(*params, referenceParams); err != nil {
return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidParams}
}
return h.references(referenceParams)
default:
return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeMethodNotFound}
}
Expand All @@ -195,16 +202,38 @@ func (h *Handler) initialize(params *lsp.InitializeParams) (*lsp.InitializeResul
}
h.state = core.NewBuildState(config)
h.state.NeedBuild = false
// We need an unwrapped parser instance as well for raw access.
h.parser = asp.NewParser(h.state)
// Initialize the parser on state first, so that plz.RunHost uses the same parser.
// This ensures plugin subincludes are stored in the same AST cache we use.
parse.InitParser(h.state)
h.parser = parse.GetAspParser(h.state)
if h.parser == nil {
return nil, fmt.Errorf("failed to get asp parser from state")
}
// Parse everything in the repo up front.
// This is a lot easier than trying to do clever partial parses later on, although
// eventually we may want that if we start dealing with truly large repos.
go func() {
// Start a goroutine to periodically load parser functions as they become available.
// This allows go-to-definition to work progressively while the full parse runs.
done := make(chan struct{})
go func() {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-done:
return
case <-ticker.C:
h.loadParserFunctions()
}
}
}()
plz.RunHost(core.WholeGraph, h.state)
close(done)
log.Debug("initial parse complete")
h.buildPackageTree()
log.Debug("built completion package tree")
h.loadParserFunctions()
}()
// Record all the builtin functions now
if err := h.loadBuiltins(); err != nil {
Expand All @@ -221,6 +250,7 @@ func (h *Handler) initialize(params *lsp.InitializeParams) (*lsp.InitializeResul
DocumentFormattingProvider: true,
DocumentSymbolProvider: true,
DefinitionProvider: true,
ReferencesProvider: true,
CompletionProvider: &lsp.CompletionOptions{
TriggerCharacters: []string{"/", ":"},
},
Expand Down Expand Up @@ -268,6 +298,37 @@ func (h *Handler) loadBuiltins() error {
return nil
}

// loadParserFunctions loads function definitions from the parser's ASTs.
// This includes plugin-defined functions like go_library, python_library, etc.
func (h *Handler) loadParserFunctions() {
funcsByFile := h.parser.AllFunctionsByFile()
if funcsByFile == nil {
return
}
h.mutex.Lock()
defer h.mutex.Unlock()
for filename, stmts := range funcsByFile {
// Read the file to create a File object for position conversion
data, err := os.ReadFile(filename)
if err != nil {
log.Warning("failed to read file %s: %v", filename, err)
continue
}
file := asp.NewFile(filename, data)
for _, stmt := range stmts {
name := stmt.FuncDef.Name
// Only add if not already present (don't override core builtins)
if _, present := h.builtins[name]; !present {
h.builtins[name] = builtin{
Stmt: stmt,
Pos: file.Pos(stmt.Pos),
EndPos: file.Pos(stmt.EndPos),
}
}
}
}
}

// fromURI converts a DocumentURI to a path.
func fromURI(uri lsp.DocumentURI) string {
if !strings.HasPrefix(string(uri), "file://") {
Expand Down
Loading
Loading