diff --git a/rewrite-go/rewrite/pkg/parser/go_parser.go b/rewrite-go/rewrite/pkg/parser/go_parser.go index ab3c782c5a..8b9264c551 100644 --- a/rewrite-go/rewrite/pkg/parser/go_parser.go +++ b/rewrite-go/rewrite/pkg/parser/go_parser.go @@ -736,19 +736,39 @@ func (ctx *parseContext) mapFieldListAsParams(fl *ast.FieldList) tree.Container[ } } - closeParen := ctx.prefix(fl.Closing) - ctx.skip(1) // ")" - + var markers tree.Markers if len(elements) > 0 { - elements[len(elements)-1].After = closeParen - } else if len(closeParen.Comments) > 0 { - elements = append(elements, tree.RightPadded[tree.Statement]{ - Element: &tree.Empty{ID: uuid.New()}, - After: closeParen, - }) + trailingCommaOff := ctx.findNextBefore(',', int(fl.Closing)-ctx.file.Base()) + if trailingCommaOff >= 0 { + commaBefore := ctx.prefix(ctx.file.Pos(trailingCommaOff)) + ctx.skip(1) // "," + commaAfter := ctx.prefix(fl.Closing) + ctx.skip(1) // ")" + markers = tree.Markers{ + ID: uuid.New(), + Entries: []tree.Marker{tree.TrailingComma{ + Ident: uuid.New(), + Before: commaBefore, + After: commaAfter, + }}, + } + } else { + closePrefix := ctx.prefix(fl.Closing) + ctx.skip(1) // ")" + elements[len(elements)-1].After = closePrefix + } + } else { + closeParen := ctx.prefix(fl.Closing) + ctx.skip(1) // ")" + if len(closeParen.Comments) > 0 { + elements = append(elements, tree.RightPadded[tree.Statement]{ + Element: &tree.Empty{ID: uuid.New()}, + After: closeParen, + }) + } } - return tree.Container[tree.Statement]{Before: before, Elements: elements} + return tree.Container[tree.Statement]{Before: before, Elements: elements, Markers: markers} } // mapBlockStmt maps a block statement. diff --git a/rewrite-go/rewrite/pkg/printer/go_printer.go b/rewrite-go/rewrite/pkg/printer/go_printer.go index 7b6081c345..3f2c347a95 100644 --- a/rewrite-go/rewrite/pkg/printer/go_printer.go +++ b/rewrite-go/rewrite/pkg/printer/go_printer.go @@ -230,11 +230,16 @@ func (p *GoPrinter) VisitMethodDeclaration(md *tree.MethodDeclaration, param any func (p *GoPrinter) printParamList(params tree.Container[tree.Statement], out *PrintOutputCapture) { p.visitSpace(params.Before, out) out.Append("(") + tc := tree.FindMarker[tree.TrailingComma](params.Markers) for i, rp := range params.Elements { p.Visit(rp.Element, out) if i < len(params.Elements)-1 { p.visitSpace(rp.After, out) out.Append(",") + } else if tc != nil { + p.visitSpace(tc.Before, out) + out.Append(",") + p.visitSpace(tc.After, out) } else { p.visitSpace(rp.After, out) } @@ -792,16 +797,7 @@ func (p *GoPrinter) VisitFuncType(ft *tree.FuncType, param any) tree.J { out := param.(*PrintOutputCapture) p.beforeSyntax(ft.Prefix, ft.Markers, out) out.Append("func") - p.visitSpace(ft.Parameters.Before, out) - out.Append("(") - for i, rp := range ft.Parameters.Elements { - p.Visit(rp.Element, out) - if i < len(ft.Parameters.Elements)-1 { - p.visitSpace(rp.After, out) - out.Append(",") - } - } - out.Append(")") + p.printParamList(ft.Parameters, out) if ft.ReturnType != nil { p.Visit(ft.ReturnType, out) } diff --git a/rewrite-go/rewrite/test/trailing_comma_test.go b/rewrite-go/rewrite/test/trailing_comma_test.go index fe253bc7ac..4d3dea9276 100644 --- a/rewrite-go/rewrite/test/trailing_comma_test.go +++ b/rewrite-go/rewrite/test/trailing_comma_test.go @@ -149,6 +149,18 @@ func TestParseTrailingCommaAnonymousStruct(t *testing.T) { `)) } +func TestParseTrailingCommaFuncTypeParams(t *testing.T) { + NewRecipeSpec().RewriteRun(t, + Golang(` + package main + + var f = map[string]func( + a int, + b int, + ) error{} + `)) +} + func TestParseTrailingCommaMapOfSlices(t *testing.T) { NewRecipeSpec().RewriteRun(t, Golang(`