diff --git a/cmd/jzero/internal/command/gen/genapi/genapi.go b/cmd/jzero/internal/command/gen/genapi/genapi.go index de49066bc..8cf9dd53c 100644 --- a/cmd/jzero/internal/command/gen/genapi/genapi.go +++ b/cmd/jzero/internal/command/gen/genapi/genapi.go @@ -183,121 +183,229 @@ func (ja *JzeroApi) Gen() (map[string]*spec.ApiSpec, error) { } if !config.C.Quiet { - fmt.Println(console.Green("Done")) + fmt.Println(console.Green("Gen Api Done")) } return apiSpecMap, nil } func (ja *JzeroApi) generateApiCode(apiFiles []string, apiSpecMap map[string]*spec.ApiSpec, genCodeApiFiles []string, genCodeApiSpecMap map[string]*spec.ApiSpec, currentRoutesMap map[string][]spec.Route, importedFiles map[string]bool) error { + if err := ja.cleanHandlersDir(genCodeApiFiles, genCodeApiSpecMap); err != nil { + return err + } + + templateDir, err := ja.prepareTemplateDir() + if err != nil { + return err + } + defer os.RemoveAll(templateDir) + + allRoutesGoBody, err := ja.collectRoutesGoBody(apiFiles, apiSpecMap, currentRoutesMap, importedFiles) + if err != nil { + return err + } + + if err := ja.generateCodeForApiFiles(genCodeApiFiles, apiSpecMap, importedFiles, templateDir); err != nil { + return err + } + + if err := ja.patchHandlerAndLogicFiles(genCodeApiFiles, apiSpecMap, genCodeApiSpecMap); err != nil { + return err + } + + if err := ja.generateRoutesGoFile(apiFiles, apiSpecMap, importedFiles, allRoutesGoBody); err != nil { + return err + } + + if config.C.Gen.Route2Code { + if err := ja.generateRoute2CodeFile(apiSpecMap, currentRoutesMap, importedFiles); err != nil { + return err + } + } + + return nil +} + +// cleanHandlersDir 清理 handler 目录下的旧文件 +func (ja *JzeroApi) cleanHandlersDir(genCodeApiFiles []string, genCodeApiSpecMap map[string]*spec.ApiSpec) error { + var eg errgroup.Group for _, file := range genCodeApiFiles { - if parse, ok := genCodeApiSpecMap[file]; ok { - for _, group := range parse.Service.Groups { - dirFile, err := os.ReadDir(filepath.Join(config.C.Wd(), "internal", "handler", group.GetAnnotation("group"))) - if err == nil { - for _, v := range dirFile { - if !v.IsDir() { - _ = os.Remove(filepath.Join(config.C.Wd(), "internal", "handler", group.GetAnnotation("group"), v.Name())) - } + apiSpec, ok := genCodeApiSpecMap[file] + if !ok { + continue + } + + for _, group := range apiSpec.Service.Groups { + groupAnnotation := group.GetAnnotation("group") + if groupAnnotation == "" { + continue + } + + handlerDir := filepath.Join(config.C.Wd(), "internal", "handler", groupAnnotation) + eg.Go(func() error { + dirEntries, err := os.ReadDir(handlerDir) + if err != nil { + return nil // 目录不存在或无法读取,忽略错误 + } + for _, entry := range dirEntries { + if !entry.IsDir() { + _ = os.Remove(filepath.Join(handlerDir, entry.Name())) } } - } + return nil + }) } } + return eg.Wait() +} - // 处理模板 - var goctlHome string +// prepareTemplateDir 准备模板目录,返回临时目录路径 +func (ja *JzeroApi) prepareTemplateDir() (string, error) { tempDir, err := os.MkdirTemp(os.TempDir(), "") if err != nil { - return err + return "", err } - defer os.RemoveAll(tempDir) - // 先写入内置模板 - err = embeded.WriteTemplateDir(filepath.Join("go-zero", "api"), filepath.Join(tempDir, "api")) - if err != nil { - return err + // 写入内置模板 + if err := embeded.WriteTemplateDir(filepath.Join("go-zero", "api"), filepath.Join(tempDir, "api")); err != nil { + _ = os.RemoveAll(tempDir) + return "", err } // 如果用户自定义了模板,则复制覆盖 customTemplatePath := filepath.Join(config.C.Home, "go-zero", "api") if pathx.FileExists(customTemplatePath) { - err = filex.CopyDir(customTemplatePath, filepath.Join(tempDir, "api")) - if err != nil { - return err + if err := filex.CopyDir(customTemplatePath, filepath.Join(tempDir, "api")); err != nil { + _ = os.RemoveAll(tempDir) + return "", err } } - goctlHome = tempDir - logx.Debugf("goctl_home = %s", goctlHome) + logx.Debugf("goctl_home = %s", tempDir) + return tempDir, nil +} - var handlerImports ImportLines - var allRoutesGoBody string +// collectRoutesGoBody 并发收集所有文件的 routesGoBody +func (ja *JzeroApi) collectRoutesGoBody(apiFiles []string, apiSpecMap map[string]*spec.ApiSpec, currentRoutesMap map[string][]spec.Route, importedFiles map[string]bool) (string, error) { var allRoutesGoBodyMap sync.Map var eg errgroup.Group eg.SetLimit(len(apiFiles)) - for _, v := range apiFiles { - // 跳过被 import 的文件 - if importedFiles[v] { + + for _, apiFile := range apiFiles { + if importedFiles[apiFile] { continue } - cv := v + currentFile := apiFile eg.Go(func() error { - routesGoBody, err := ja.getRoutesGoBody(cv, apiSpecMap, currentRoutesMap) + routesGoBody, err := ja.getRoutesGoBody(currentFile, apiSpecMap, currentRoutesMap) if err != nil { return err } if routesGoBody != "" { - allRoutesGoBodyMap.Store(cv, routesGoBody) + allRoutesGoBodyMap.Store(currentFile, routesGoBody) } return nil }) } if err := eg.Wait(); err != nil { - return err + return "", err } - for _, v := range apiFiles { - if s, ok := allRoutesGoBodyMap.Load(v); ok { - allRoutesGoBody += cast.ToString(s) + "\n" + var allRoutesGoBody strings.Builder + for _, apiFile := range apiFiles { + if body, ok := allRoutesGoBodyMap.Load(apiFile); ok { + allRoutesGoBody.WriteString(cast.ToString(body)) + allRoutesGoBody.WriteString("\n") } } - for _, v := range genCodeApiFiles { - // 跳过被 import 的文件 - if importedFiles[v] { + return allRoutesGoBody.String(), nil +} + +// generateCodeForApiFiles 为所有 API 文件生成代码 +func (ja *JzeroApi) generateCodeForApiFiles(genCodeApiFiles []string, apiSpecMap map[string]*spec.ApiSpec, importedFiles map[string]bool, templateDir string) error { + // 按 group 分组,同一 group 的文件串行处理,不同 group 并发处理 + groupToFiles := make(map[string][]string) + for _, apiFile := range genCodeApiFiles { + if importedFiles[apiFile] { + continue + } + if len(apiSpecMap[apiFile].Service.Routes()) == 0 { continue } - if len(apiSpecMap[v].Service.Routes()) > 0 { - logicFiles, err := ja.getAllLogicFiles(v, apiSpecMap[v]) - if err != nil { - return err + // 收集该文件的所有 group + groups := make(map[string]struct{}) + for _, g := range apiSpecMap[apiFile].Service.Groups { + if groupAnnotation := g.GetAnnotation("group"); groupAnnotation != "" { + groups[groupAnnotation] = struct{}{} } + } - handlerFiles, err := ja.getAllHandlerFiles(v, apiSpecMap[v]) - if err != nil { - return err + // 如果没有 group,使用默认分组 + if len(groups) == 0 { + groupToFiles[""] = append(groupToFiles[""], apiFile) + } else { + for group := range groups { + groupToFiles[group] = append(groupToFiles[group], apiFile) } + } + } + + // 并发处理不同 group + var eg errgroup.Group + for _, files := range groupToFiles { + currentFiles := files + eg.Go(func() error { + // 同一 group 内的文件串行处理 + for _, apiFile := range currentFiles { + if !config.C.Quiet { + fmt.Printf("%s api file %s \n", console.Green("Using"), apiFile) + } - dir := "." - if !config.C.Quiet { - fmt.Printf("%s api file %s\n", console.Green("Using"), v) + if err := format.ApiFormatByPath(apiFile, false); err != nil { + return errors.Wrapf(err, "format api file: %s", apiFile) + } + + command := fmt.Sprintf("goctl api go --api %s --dir %s --home %s --style %s", apiFile, ".", templateDir, config.C.Style) + logx.Debugf("command: %s", command) + + if _, err := execx.Run(command, config.C.Wd()); err != nil { + return errors.Wrapf(err, "api file: %s", apiFile) + } } + return nil + }) + } - if err = format.ApiFormatByPath(v, false); err != nil { - return errors.Wrapf(err, "format api file: %s", v) + return eg.Wait() +} + +// patchHandlerAndLogicFiles 并发 patch handler 和 logic 文件 +func (ja *JzeroApi) patchHandlerAndLogicFiles(genCodeApiFiles []string, apiSpecMap map[string]*spec.ApiSpec, genCodeApiSpecMap map[string]*spec.ApiSpec) error { + var eg errgroup.Group + + for _, apiFile := range genCodeApiFiles { + if len(apiSpecMap[apiFile].Service.Routes()) == 0 { + continue + } + + currentFile := apiFile + + eg.Go(func() error { + logicFiles, err := ja.getAllLogicFiles(currentFile, apiSpecMap[currentFile]) + if err != nil { + return err } - command := fmt.Sprintf("goctl api go --api %s --dir %s --home %s --style %s", v, dir, goctlHome, config.C.Style) - logx.Debugf("command: %s", command) - if _, err := execx.Run(command, config.C.Wd()); err != nil { - return errors.Wrapf(err, "api file: %s", v) + handlerFiles, err := ja.getAllHandlerFiles(currentFile, apiSpecMap[currentFile]) + if err != nil { + return err } - // patch handler + // Patch handler files for _, file := range handlerFiles { if _, ok := genCodeApiSpecMap[file.ApiFilepath]; ok { if err = ja.patchHandler(file, genCodeApiSpecMap); err != nil { @@ -305,6 +413,8 @@ func (ja *JzeroApi) generateApiCode(apiFiles []string, apiSpecMap map[string]*sp } } } + + // Patch logic files for _, file := range logicFiles { if _, ok := genCodeApiSpecMap[file.DescFilepath]; ok { if err = ja.patchLogic(file, genCodeApiSpecMap); err != nil { @@ -312,52 +422,72 @@ func (ja *JzeroApi) generateApiCode(apiFiles []string, apiSpecMap map[string]*sp } } } - } + + return nil + }) } - for _, v := range apiFiles { - // 跳过被 import 的文件 - if importedFiles[v] { + return eg.Wait() +} + +// generateRoutesGoFile 生成 routes.go 文件 +func (ja *JzeroApi) generateRoutesGoFile(apiFiles []string, apiSpecMap map[string]*spec.ApiSpec, importedFiles map[string]bool, allRoutesGoBody string) error { + var handlerImports ImportLines + + for _, apiFile := range apiFiles { + if importedFiles[apiFile] { continue } - for _, g := range apiSpecMap[v].Service.Groups { - if g.GetAnnotation("group") != "" { - handlerImports = append(handlerImports, fmt.Sprintf(`%s "%s/internal/handler/%s"`, strings.ReplaceAll(g.GetAnnotation("group"), "/", ""), ja.Module, g.GetAnnotation("group"))) + for _, group := range apiSpecMap[apiFile].Service.Groups { + groupAnnotation := group.GetAnnotation("group") + if groupAnnotation != "" { + importPath := fmt.Sprintf(`%s "%s/internal/handler/%s"`, + strings.ReplaceAll(groupAnnotation, "/", ""), ja.Module, groupAnnotation) + handlerImports = append(handlerImports, importPath) } } } - template, err := templatex.ParseTemplate(filepath.Join("api", "routes.go.tpl"), map[string]any{ - "Routes": allRoutesGoBody, - "Module": ja.Module, - "HandlerImports": lo.Uniq(handlerImports), - }, embeded.ReadTemplateFile(filepath.Join("api", "routes.go.tpl"))) + templateContent, err := templatex.ParseTemplate( + filepath.Join("api", "routes.go.tpl"), + map[string]any{ + "Routes": allRoutesGoBody, + "Module": ja.Module, + "HandlerImports": lo.Uniq(handlerImports), + }, + embeded.ReadTemplateFile(filepath.Join("api", "routes.go.tpl")), + ) if err != nil { return err } - process, err := gosimports.Process("", template, nil) + + process, err := gosimports.Process("", templateContent, nil) if err != nil { return err } - if err = os.WriteFile(filepath.Join("internal", "handler", "routes.go"), process, 0o644); err != nil { + + return os.WriteFile(filepath.Join("internal", "handler", "routes.go"), process, 0o644) +} + +// generateRoute2CodeFile 生成 route2code.go 文件 +func (ja *JzeroApi) generateRoute2CodeFile(apiSpecMap map[string]*spec.ApiSpec, currentRoutesMap map[string][]spec.Route, importedFiles map[string]bool) error { + if !config.C.Quiet { + fmt.Printf("%s to generate internal/handler/route2code.go\n", console.Green("Start")) + } + + route2CodeBytes, err := ja.genRoute2Code(apiSpecMap, currentRoutesMap, importedFiles) + if err != nil { return err } - if config.C.Gen.Route2Code { - if !config.C.Quiet { - fmt.Printf("%s to generate internal/handler/route2code.go\n", console.Green("Start")) - } - if route2CodeBytes, err := ja.genRoute2Code(apiSpecMap, currentRoutesMap, importedFiles); err != nil { - return err - } else { - if err = os.WriteFile(filepath.Join("internal", "handler", "route2code.go"), route2CodeBytes, 0o644); err != nil { - return err - } - } - if !config.C.Quiet { - fmt.Printf("%s", console.Green("Done\n")) - } + if err := os.WriteFile(filepath.Join("internal", "handler", "route2code.go"), route2CodeBytes, 0o644); err != nil { + return err } + + if !config.C.Quiet { + fmt.Printf("%s", console.Green("Gen Route2code Done\n")) + } + return nil } diff --git a/cmd/jzero/internal/command/gen/genapi/types.go b/cmd/jzero/internal/command/gen/genapi/types.go index 54f55669e..154a4e456 100644 --- a/cmd/jzero/internal/command/gen/genapi/types.go +++ b/cmd/jzero/internal/command/gen/genapi/types.go @@ -7,31 +7,111 @@ import ( "os" "path/filepath" "strings" + "sync" "github.com/rinchsan/gosimports" "github.com/samber/lo" "github.com/zeromicro/go-zero/tools/goctl/api/gogen" "github.com/zeromicro/go-zero/tools/goctl/api/spec" "github.com/zeromicro/go-zero/tools/goctl/util" + "golang.org/x/sync/errgroup" "golang.org/x/tools/go/ast/astutil" "github.com/jzero-io/jzero/cmd/jzero/internal/pkg/templatex" ) func (ja *JzeroApi) separateTypesGo(apiFiles []string, apiSpecMap map[string]*spec.ApiSpec) error { - var allTypes []spec.Type + _, typesWithoutPackage, err := ja.collectAndGenerateTypesByPackage(apiFiles, apiSpecMap) + if err != nil { + return err + } - for _, apiFile := range apiFiles { - typesGoString, err := gogen.BuildTypes(apiSpecMap[apiFile].Types) - if err != nil { + // 去重并生成默认 types.go + if len(typesWithoutPackage) > 0 { + if err := ja.generateDefaultTypesFile(typesWithoutPackage); err != nil { return err } - goPackage, ok := apiSpecMap[apiFile].Info.Properties["go_package"] - if ok && goPackage != "" { - typesGoBytes, err := templatex.ParseTemplate("inner_types.go", map[string]any{ - "Types": typesGoString, - "Package": strings.ToLower(strings.ReplaceAll(goPackage, "/", "")), - }, []byte(`// Code generated by jzero. DO NOT EDIT. + } + + return nil +} + +// apiFileTypes 表示 API 文件的类型信息 +type apiFileTypes struct { + file string + types []spec.Type + goPackage string + typesGoBytes []byte +} + +// collectAndGenerateTypesByPackage 收集并按包生成 types.go 文件 +func (ja *JzeroApi) collectAndGenerateTypesByPackage(apiFiles []string, apiSpecMap map[string]*spec.ApiSpec) (typesWithPackage []apiFileTypes, typesWithoutPackage []spec.Type, err error) { + var eg errgroup.Group + var mu sync.Mutex + results := make([]apiFileTypes, 0, len(apiFiles)) + + for _, apiFile := range apiFiles { + currentFile := apiFile + currentSpec := apiSpecMap[apiFile] + + eg.Go(func() error { + fileTypes, err := ja.processApiFileTypes(currentFile, currentSpec) + if err != nil { + return err + } + + mu.Lock() + if fileTypes.goPackage != "" { + results = append(results, fileTypes) + } else { + typesWithoutPackage = append(typesWithoutPackage, fileTypes.types...) + } + mu.Unlock() + + return nil + }) + } + + if err := eg.Wait(); err != nil { + return nil, nil, err + } + + // 并发生成有 go_package 的 types.go 文件 + eg = errgroup.Group{} + for _, fileTypes := range results { + currentTypes := fileTypes + eg.Go(func() error { + return ja.writeTypesFile(currentTypes.goPackage, currentTypes.typesGoBytes) + }) + } + + if err := eg.Wait(); err != nil { + return nil, nil, err + } + + return results, typesWithoutPackage, nil +} + +// processApiFileTypes 处理单个 API 文件的类型 +func (ja *JzeroApi) processApiFileTypes(apiFile string, apiSpec *spec.ApiSpec) (apiFileTypes, error) { + typesGoString, err := gogen.BuildTypes(apiSpec.Types) + if err != nil { + return apiFileTypes{}, err + } + + goPackage, hasGoPackage := apiSpec.Info.Properties["go_package"] + if !hasGoPackage || goPackage == "" { + return apiFileTypes{ + file: apiFile, + types: apiSpec.Types, + }, nil + } + + packageName := strings.ToLower(strings.ReplaceAll(goPackage, "/", "")) + typesGoBytes, err := templatex.ParseTemplate("inner_types.go", map[string]any{ + "Types": typesGoString, + "Package": packageName, + }, []byte(`// Code generated by jzero. DO NOT EDIT. package {{.Package}} import ( @@ -43,38 +123,43 @@ var ( ) {{.Types}}`)) - if err != nil { - return err - } + if err != nil { + return apiFileTypes{}, err + } - _ = os.MkdirAll(filepath.Join("internal", "types", goPackage), 0o755) - process, err := gosimports.Process("", typesGoBytes, nil) - if err != nil { - return err - } - if err = os.WriteFile(filepath.Join("internal", "types", goPackage, "types.go"), process, 0o644); err != nil { - return err - } - } else { - allTypes = append(allTypes, apiSpecMap[apiFile].Types...) - } + return apiFileTypes{ + file: apiFile, + types: apiSpec.Types, + goPackage: goPackage, + typesGoBytes: typesGoBytes, + }, nil +} + +// writeTypesFile 写入 types.go 文件 +func (ja *JzeroApi) writeTypesFile(goPackage string, typesGoBytes []byte) error { + typesDir := filepath.Join("internal", "types", goPackage) + if err := os.MkdirAll(typesDir, 0o755); err != nil { + return err } - // 去除重复 - var realAllTypes []spec.Type - exist := make(map[string]struct{}) - for _, v := range allTypes { - if _, ok := exist[v.Name()]; ok { - continue - } - realAllTypes = append(realAllTypes, v) - exist[v.Name()] = struct{}{} + process, err := gosimports.Process("", typesGoBytes, nil) + if err != nil { + return err } - typesGoString, err := gogen.BuildTypes(realAllTypes) + return os.WriteFile(filepath.Join(typesDir, "types.go"), process, 0o644) +} + +// generateDefaultTypesFile 生成默认的 types.go 文件 +func (ja *JzeroApi) generateDefaultTypesFile(allTypes []spec.Type) error { + // 去重 + uniqueTypes := ja.deduplicateTypes(allTypes) + + typesGoString, err := gogen.BuildTypes(uniqueTypes) if err != nil { return err } + typesGoBytes, err := templatex.ParseTemplate("inner_types.go", map[string]any{ "Types": typesGoString, }, []byte(`// Code generated by jzero. DO NOT EDIT. @@ -92,14 +177,27 @@ var ( if err != nil { return err } + process, err := gosimports.Process("", typesGoBytes, nil) if err != nil { return err } - if err = os.WriteFile(filepath.Join("internal", "types", "types.go"), process, 0o644); err != nil { - return err + + return os.WriteFile(filepath.Join("internal", "types", "types.go"), process, 0o644) +} + +// deduplicateTypes 去重类型列表 +func (ja *JzeroApi) deduplicateTypes(types []spec.Type) []spec.Type { + var result []spec.Type + exist := make(map[string]struct{}) + for _, t := range types { + if _, ok := exist[t.Name()]; ok { + continue + } + result = append(result, t) + exist[t.Name()] = struct{}{} } - return nil + return result } func (ja *JzeroApi) updateHandlerImportedTypesPath(f *ast.File, fset *token.FileSet, file HandlerFile) error { diff --git a/cmd/jzero/internal/command/gen/genmodel/genmodel.go b/cmd/jzero/internal/command/gen/genmodel/genmodel.go index d3254da2f..7d6855ec1 100644 --- a/cmd/jzero/internal/command/gen/genmodel/genmodel.go +++ b/cmd/jzero/internal/command/gen/genmodel/genmodel.go @@ -264,7 +264,7 @@ func (jm *JzeroModel) Gen() error { } if !config.C.Quiet { - fmt.Println(console.Green("Done")) + fmt.Println(console.Green("Gen Model Done")) } return nil diff --git a/cmd/jzero/internal/command/gen/genmongo/genmongo.go b/cmd/jzero/internal/command/gen/genmongo/genmongo.go index 88c4107d2..117285568 100644 --- a/cmd/jzero/internal/command/gen/genmongo/genmongo.go +++ b/cmd/jzero/internal/command/gen/genmongo/genmongo.go @@ -118,7 +118,7 @@ func (jm *JzeroMongo) Gen() error { } if !config.C.Quiet { - fmt.Println(console.Green("Done")) + fmt.Println(console.Green("Gen Mongo Done")) } return nil diff --git a/cmd/jzero/internal/command/gen/genrpc/genrpc.go b/cmd/jzero/internal/command/gen/genrpc/genrpc.go index 50502f545..04e54629d 100644 --- a/cmd/jzero/internal/command/gen/genrpc/genrpc.go +++ b/cmd/jzero/internal/command/gen/genrpc/genrpc.go @@ -345,7 +345,7 @@ func (jr *JzeroRpc) Gen() (map[string]rpcparser.Proto, error) { if len(genCodeProtoFiles) > 0 { if !config.C.Quiet { - fmt.Println(console.Green("Done")) + fmt.Println(console.Green("Gen Rpc Done")) } } diff --git a/cmd/jzero/internal/command/gen/genrpc/middleware.go b/cmd/jzero/internal/command/gen/genrpc/middleware.go index 65098d71e..c7b5bdd02 100644 --- a/cmd/jzero/internal/command/gen/genrpc/middleware.go +++ b/cmd/jzero/internal/command/gen/genrpc/middleware.go @@ -211,7 +211,7 @@ func (jr *JzeroRpc) genApiMiddlewares(protoFiles []string) (err error) { } if !config.C.Quiet { - fmt.Printf("%s\n", console.Green("Done")) + fmt.Printf("%s\n", console.Green("Gen Rpc Middleware Done")) } return nil } diff --git a/cmd/jzero/internal/command/gen/genrpc/server.go b/cmd/jzero/internal/command/gen/genrpc/server.go index 4912fac64..aebc79181 100644 --- a/cmd/jzero/internal/command/gen/genrpc/server.go +++ b/cmd/jzero/internal/command/gen/genrpc/server.go @@ -45,7 +45,7 @@ func (jr *JzeroRpc) genServer(serverImports, pbImports ImportLines, registerServ return err } if !config.C.Quiet { - fmt.Printf("%s", console.Green("Done\n")) + fmt.Printf("%s", console.Green("Gen Server Done\n")) } return nil }