Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
296 changes: 213 additions & 83 deletions cmd/jzero/internal/command/gen/genapi/genapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,181 +183,311 @@ func (ja *JzeroApi) Gen() (map[string]*spec.ApiSpec, error) {
}

if !config.C.Quiet {
fmt.Println(console.Green("Done"))
fmt.Println(console.Green("Gen Api Done"))
}
return apiSpecMap, nil
}

func (ja *JzeroApi) generateApiCode(apiFiles []string, apiSpecMap map[string]*spec.ApiSpec, genCodeApiFiles []string, genCodeApiSpecMap map[string]*spec.ApiSpec, currentRoutesMap map[string][]spec.Route, importedFiles map[string]bool) error {
if err := ja.cleanHandlersDir(genCodeApiFiles, genCodeApiSpecMap); err != nil {
return err
}

templateDir, err := ja.prepareTemplateDir()
if err != nil {
return err
}
defer os.RemoveAll(templateDir)

allRoutesGoBody, err := ja.collectRoutesGoBody(apiFiles, apiSpecMap, currentRoutesMap, importedFiles)
if err != nil {
return err
}

if err := ja.generateCodeForApiFiles(genCodeApiFiles, apiSpecMap, importedFiles, templateDir); err != nil {
return err
}

if err := ja.patchHandlerAndLogicFiles(genCodeApiFiles, apiSpecMap, genCodeApiSpecMap); err != nil {
return err
}

if err := ja.generateRoutesGoFile(apiFiles, apiSpecMap, importedFiles, allRoutesGoBody); err != nil {
return err
}

if config.C.Gen.Route2Code {
if err := ja.generateRoute2CodeFile(apiSpecMap, currentRoutesMap, importedFiles); err != nil {
return err
}
}

return nil
}

// cleanHandlersDir 清理 handler 目录下的旧文件
func (ja *JzeroApi) cleanHandlersDir(genCodeApiFiles []string, genCodeApiSpecMap map[string]*spec.ApiSpec) error {
var eg errgroup.Group
for _, file := range genCodeApiFiles {
if parse, ok := genCodeApiSpecMap[file]; ok {
for _, group := range parse.Service.Groups {
dirFile, err := os.ReadDir(filepath.Join(config.C.Wd(), "internal", "handler", group.GetAnnotation("group")))
if err == nil {
for _, v := range dirFile {
if !v.IsDir() {
_ = os.Remove(filepath.Join(config.C.Wd(), "internal", "handler", group.GetAnnotation("group"), v.Name()))
}
apiSpec, ok := genCodeApiSpecMap[file]
if !ok {
continue
}

for _, group := range apiSpec.Service.Groups {
groupAnnotation := group.GetAnnotation("group")
if groupAnnotation == "" {
continue
}

handlerDir := filepath.Join(config.C.Wd(), "internal", "handler", groupAnnotation)
eg.Go(func() error {
dirEntries, err := os.ReadDir(handlerDir)
if err != nil {
return nil // 目录不存在或无法读取,忽略错误
}
for _, entry := range dirEntries {
if !entry.IsDir() {
_ = os.Remove(filepath.Join(handlerDir, entry.Name()))
}
}
}
return nil
})
}
}
return eg.Wait()
}

// 处理模板
var goctlHome string
// prepareTemplateDir 准备模板目录,返回临时目录路径
func (ja *JzeroApi) prepareTemplateDir() (string, error) {
tempDir, err := os.MkdirTemp(os.TempDir(), "")
if err != nil {
return err
return "", err
}
defer os.RemoveAll(tempDir)

// 先写入内置模板
err = embeded.WriteTemplateDir(filepath.Join("go-zero", "api"), filepath.Join(tempDir, "api"))
if err != nil {
return err
// 写入内置模板
if err := embeded.WriteTemplateDir(filepath.Join("go-zero", "api"), filepath.Join(tempDir, "api")); err != nil {
_ = os.RemoveAll(tempDir)
return "", err
}

// 如果用户自定义了模板,则复制覆盖
customTemplatePath := filepath.Join(config.C.Home, "go-zero", "api")
if pathx.FileExists(customTemplatePath) {
err = filex.CopyDir(customTemplatePath, filepath.Join(tempDir, "api"))
if err != nil {
return err
if err := filex.CopyDir(customTemplatePath, filepath.Join(tempDir, "api")); err != nil {
_ = os.RemoveAll(tempDir)
return "", err
}
}

goctlHome = tempDir
logx.Debugf("goctl_home = %s", goctlHome)
logx.Debugf("goctl_home = %s", tempDir)
return tempDir, nil
}

var handlerImports ImportLines
var allRoutesGoBody string
// collectRoutesGoBody 并发收集所有文件的 routesGoBody
func (ja *JzeroApi) collectRoutesGoBody(apiFiles []string, apiSpecMap map[string]*spec.ApiSpec, currentRoutesMap map[string][]spec.Route, importedFiles map[string]bool) (string, error) {
var allRoutesGoBodyMap sync.Map

var eg errgroup.Group
eg.SetLimit(len(apiFiles))
for _, v := range apiFiles {
// 跳过被 import 的文件
if importedFiles[v] {

for _, apiFile := range apiFiles {
if importedFiles[apiFile] {
continue
}

cv := v
currentFile := apiFile
eg.Go(func() error {
routesGoBody, err := ja.getRoutesGoBody(cv, apiSpecMap, currentRoutesMap)
routesGoBody, err := ja.getRoutesGoBody(currentFile, apiSpecMap, currentRoutesMap)
if err != nil {
return err
}
if routesGoBody != "" {
allRoutesGoBodyMap.Store(cv, routesGoBody)
allRoutesGoBodyMap.Store(currentFile, routesGoBody)
}
return nil
})
}

if err := eg.Wait(); err != nil {
return err
return "", err
}

for _, v := range apiFiles {
if s, ok := allRoutesGoBodyMap.Load(v); ok {
allRoutesGoBody += cast.ToString(s) + "\n"
var allRoutesGoBody strings.Builder
for _, apiFile := range apiFiles {
if body, ok := allRoutesGoBodyMap.Load(apiFile); ok {
allRoutesGoBody.WriteString(cast.ToString(body))
allRoutesGoBody.WriteString("\n")
}
}

for _, v := range genCodeApiFiles {
// 跳过被 import 的文件
if importedFiles[v] {
return allRoutesGoBody.String(), nil
}

// generateCodeForApiFiles 为所有 API 文件生成代码
func (ja *JzeroApi) generateCodeForApiFiles(genCodeApiFiles []string, apiSpecMap map[string]*spec.ApiSpec, importedFiles map[string]bool, templateDir string) error {
// 按 group 分组,同一 group 的文件串行处理,不同 group 并发处理
groupToFiles := make(map[string][]string)
for _, apiFile := range genCodeApiFiles {
if importedFiles[apiFile] {
continue
}
if len(apiSpecMap[apiFile].Service.Routes()) == 0 {
continue
}

if len(apiSpecMap[v].Service.Routes()) > 0 {
logicFiles, err := ja.getAllLogicFiles(v, apiSpecMap[v])
if err != nil {
return err
// 收集该文件的所有 group
groups := make(map[string]struct{})
for _, g := range apiSpecMap[apiFile].Service.Groups {
if groupAnnotation := g.GetAnnotation("group"); groupAnnotation != "" {
groups[groupAnnotation] = struct{}{}
}
}

handlerFiles, err := ja.getAllHandlerFiles(v, apiSpecMap[v])
if err != nil {
return err
// 如果没有 group,使用默认分组
if len(groups) == 0 {
groupToFiles[""] = append(groupToFiles[""], apiFile)
} else {
for group := range groups {
groupToFiles[group] = append(groupToFiles[group], apiFile)
}
}
}

// 并发处理不同 group
var eg errgroup.Group
for _, files := range groupToFiles {
currentFiles := files
eg.Go(func() error {
// 同一 group 内的文件串行处理
for _, apiFile := range currentFiles {
if !config.C.Quiet {
fmt.Printf("%s api file %s \n", console.Green("Using"), apiFile)
}

dir := "."
if !config.C.Quiet {
fmt.Printf("%s api file %s\n", console.Green("Using"), v)
if err := format.ApiFormatByPath(apiFile, false); err != nil {
return errors.Wrapf(err, "format api file: %s", apiFile)
}

command := fmt.Sprintf("goctl api go --api %s --dir %s --home %s --style %s", apiFile, ".", templateDir, config.C.Style)
logx.Debugf("command: %s", command)

if _, err := execx.Run(command, config.C.Wd()); err != nil {
return errors.Wrapf(err, "api file: %s", apiFile)
}
}
return nil
})
}

if err = format.ApiFormatByPath(v, false); err != nil {
return errors.Wrapf(err, "format api file: %s", v)
return eg.Wait()
}

// patchHandlerAndLogicFiles 并发 patch handler 和 logic 文件
func (ja *JzeroApi) patchHandlerAndLogicFiles(genCodeApiFiles []string, apiSpecMap map[string]*spec.ApiSpec, genCodeApiSpecMap map[string]*spec.ApiSpec) error {
var eg errgroup.Group

for _, apiFile := range genCodeApiFiles {
if len(apiSpecMap[apiFile].Service.Routes()) == 0 {
continue
}

currentFile := apiFile

eg.Go(func() error {
logicFiles, err := ja.getAllLogicFiles(currentFile, apiSpecMap[currentFile])
if err != nil {
return err
}

command := fmt.Sprintf("goctl api go --api %s --dir %s --home %s --style %s", v, dir, goctlHome, config.C.Style)
logx.Debugf("command: %s", command)
if _, err := execx.Run(command, config.C.Wd()); err != nil {
return errors.Wrapf(err, "api file: %s", v)
handlerFiles, err := ja.getAllHandlerFiles(currentFile, apiSpecMap[currentFile])
if err != nil {
return err
}

// patch handler
// Patch handler files
for _, file := range handlerFiles {
if _, ok := genCodeApiSpecMap[file.ApiFilepath]; ok {
if err = ja.patchHandler(file, genCodeApiSpecMap); err != nil {
return errors.Wrapf(err, "rewrite %s", file.Path)
}
}
}

// Patch logic files
for _, file := range logicFiles {
if _, ok := genCodeApiSpecMap[file.DescFilepath]; ok {
if err = ja.patchLogic(file, genCodeApiSpecMap); err != nil {
return errors.Wrapf(err, "rewrite %s", file.Path)
}
}
}
}

return nil
})
}

for _, v := range apiFiles {
// 跳过被 import 的文件
if importedFiles[v] {
return eg.Wait()
}

// generateRoutesGoFile 生成 routes.go 文件
func (ja *JzeroApi) generateRoutesGoFile(apiFiles []string, apiSpecMap map[string]*spec.ApiSpec, importedFiles map[string]bool, allRoutesGoBody string) error {
var handlerImports ImportLines

for _, apiFile := range apiFiles {
if importedFiles[apiFile] {
continue
}

for _, g := range apiSpecMap[v].Service.Groups {
if g.GetAnnotation("group") != "" {
handlerImports = append(handlerImports, fmt.Sprintf(`%s "%s/internal/handler/%s"`, strings.ReplaceAll(g.GetAnnotation("group"), "/", ""), ja.Module, g.GetAnnotation("group")))
for _, group := range apiSpecMap[apiFile].Service.Groups {
groupAnnotation := group.GetAnnotation("group")
if groupAnnotation != "" {
importPath := fmt.Sprintf(`%s "%s/internal/handler/%s"`,
strings.ReplaceAll(groupAnnotation, "/", ""), ja.Module, groupAnnotation)
handlerImports = append(handlerImports, importPath)
}
}
}

template, err := templatex.ParseTemplate(filepath.Join("api", "routes.go.tpl"), map[string]any{
"Routes": allRoutesGoBody,
"Module": ja.Module,
"HandlerImports": lo.Uniq(handlerImports),
}, embeded.ReadTemplateFile(filepath.Join("api", "routes.go.tpl")))
templateContent, err := templatex.ParseTemplate(
filepath.Join("api", "routes.go.tpl"),
map[string]any{
"Routes": allRoutesGoBody,
"Module": ja.Module,
"HandlerImports": lo.Uniq(handlerImports),
},
embeded.ReadTemplateFile(filepath.Join("api", "routes.go.tpl")),
)
if err != nil {
return err
}
process, err := gosimports.Process("", template, nil)

process, err := gosimports.Process("", templateContent, nil)
if err != nil {
return err
}
if err = os.WriteFile(filepath.Join("internal", "handler", "routes.go"), process, 0o644); err != nil {

return os.WriteFile(filepath.Join("internal", "handler", "routes.go"), process, 0o644)
}

// generateRoute2CodeFile 生成 route2code.go 文件
func (ja *JzeroApi) generateRoute2CodeFile(apiSpecMap map[string]*spec.ApiSpec, currentRoutesMap map[string][]spec.Route, importedFiles map[string]bool) error {
if !config.C.Quiet {
fmt.Printf("%s to generate internal/handler/route2code.go\n", console.Green("Start"))
}

route2CodeBytes, err := ja.genRoute2Code(apiSpecMap, currentRoutesMap, importedFiles)
if err != nil {
return err
}

if config.C.Gen.Route2Code {
if !config.C.Quiet {
fmt.Printf("%s to generate internal/handler/route2code.go\n", console.Green("Start"))
}
if route2CodeBytes, err := ja.genRoute2Code(apiSpecMap, currentRoutesMap, importedFiles); err != nil {
return err
} else {
if err = os.WriteFile(filepath.Join("internal", "handler", "route2code.go"), route2CodeBytes, 0o644); err != nil {
return err
}
}
if !config.C.Quiet {
fmt.Printf("%s", console.Green("Done\n"))
}
if err := os.WriteFile(filepath.Join("internal", "handler", "route2code.go"), route2CodeBytes, 0o644); err != nil {
return err
}

if !config.C.Quiet {
fmt.Printf("%s", console.Green("Gen Route2code Done\n"))
}

return nil
}
Loading
Loading