diff --git a/cmd/cache_helpers_test.go b/cmd/cache_helpers_test.go index 5ce7277..e47e543 100644 --- a/cmd/cache_helpers_test.go +++ b/cmd/cache_helpers_test.go @@ -203,9 +203,23 @@ func TestGenerateManualDI_WithoutCache(t *testing.T) { } func TestAddSetupMethodsToDI_WithCache(t *testing.T) { - t.Parallel() + // Not parallel: relies on os.Chdir so hasCacheDecorator can find the file. + origDir, _ := os.Getwd() + defer os.Chdir(origDir) + dir := t.TempDir() + require.NoError(t, os.Chdir(dir)) - content := `func (c *Container) setupRepositories() { + // Cache wiring is only emitted when a decorator file exists AND the + // container exposes a redisClient field; set up both preconditions. + repoDir := filepath.Join(DirInternal, DirRepository) + require.NoError(t, os.MkdirAll(repoDir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(repoDir, "cached_product_repository.go"), []byte("package repository\n"), 0o644)) + + content := `type Container struct { + redisClient *redis.Client +} + +func (c *Container) setupRepositories() { } func (c *Container) setupUseCases() { @@ -216,11 +230,40 @@ func (c *Container) setupHandlers() { // Getters` - result := addSetupMethodsToDI(content, "Product", "product", true) + result := addSetupMethodsToDI(content, "Product", "product", "postgres", true) assert.Contains(t, result, "baseProductRepo := repository.NewPostgresProductRepository(c.db)") assert.Contains(t, result, "c.productRepo = repository.NewCachedProductRepository(baseProductRepo, c.redisClient, 5*time.Minute)") } +// TestAddSetupMethodsToDI_CacheWithoutRedisClient verifies that when the +// container has no redisClient field, the cache decorator is NOT wired (which +// would otherwise reference an undefined field) and the bare repo is used. +func TestAddSetupMethodsToDI_CacheWithoutRedisClient(t *testing.T) { + origDir, _ := os.Getwd() + defer os.Chdir(origDir) + dir := t.TempDir() + require.NoError(t, os.Chdir(dir)) + + repoDir := filepath.Join(DirInternal, DirRepository) + require.NoError(t, os.MkdirAll(repoDir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(repoDir, "cached_product_repository.go"), []byte("package repository\n"), 0o644)) + + content := `func (c *Container) setupRepositories() { +} + +func (c *Container) setupUseCases() { +} + +func (c *Container) setupHandlers() { +} + +// Getters` + + result := addSetupMethodsToDI(content, "Product", "product", "postgres", true) + assert.Contains(t, result, "c.productRepo = repository.NewPostgresProductRepository(c.db)") + assert.NotContains(t, result, "NewCachedProductRepository") +} + func TestAddSetupMethodsToDI_WithoutCache(t *testing.T) { t.Parallel() @@ -235,7 +278,7 @@ func (c *Container) setupHandlers() { // Getters` - result := addSetupMethodsToDI(content, "Order", "order", false) + result := addSetupMethodsToDI(content, "Order", "order", "postgres", false) assert.Contains(t, result, "c.orderRepo = repository.NewPostgresOrderRepository(c.db)") assert.NotContains(t, result, "baseOrderRepo") assert.NotContains(t, result, "NewCachedOrderRepository") diff --git a/cmd/ci_templates.go b/cmd/ci_templates.go index 96e93f1..de27482 100644 --- a/cmd/ci_templates.go +++ b/cmd/ci_templates.go @@ -33,8 +33,8 @@ func generateTestWorkflow(data CITemplateData) string { --health-timeout 5s --health-retries 5` envBlock = ` - env: - DATABASE_URL: postgres://postgres:postgres@localhost:5432/testdb?sslmode=disable` + env: + DATABASE_URL: postgres://postgres:postgres@localhost:5432/testdb?sslmode=disable` } else if data.Database == "mysql" { svc = ` services: @@ -51,8 +51,8 @@ func generateTestWorkflow(data CITemplateData) string { --health-timeout 5s --health-retries 5` envBlock = ` - env: - DATABASE_URL: root:root@tcp(127.0.0.1:3306)/testdb` + env: + DATABASE_URL: root:root@tcp(127.0.0.1:3306)/testdb` } return fmt.Sprintf(`name: Test diff --git a/cmd/coverage_batch5_test.go b/cmd/coverage_batch5_test.go index b1f7b79..5bbae1e 100644 --- a/cmd/coverage_batch5_test.go +++ b/cmd/coverage_batch5_test.go @@ -215,7 +215,7 @@ func TestUpdateDIContainer_NoExistingFile(t *testing.T) { os.Chdir(t.TempDir()) sm := NewSafetyManager(true, false, false) - updateDIContainer("Product", false, sm) + updateDIContainer("Product", "postgres", false, sm) // Should handle gracefully when no di/container.go exists } @@ -228,7 +228,7 @@ func TestAutoIntegrateFeature_NoMainGo(t *testing.T) { os.Chdir(t.TempDir()) sm := NewSafetyManager(true, false, false) - autoIntegrateFeature("Product", "http", false, sm) + autoIntegrateFeature("Product", "http", "postgres", false, sm) // Should handle gracefully when main.go doesnt exist } diff --git a/cmd/coverage_batch6_test.go b/cmd/coverage_batch6_test.go index b5deba8..bd2a37c 100644 --- a/cmd/coverage_batch6_test.go +++ b/cmd/coverage_batch6_test.go @@ -211,7 +211,7 @@ func TestAddFeatureToDI_NoDIFile(t *testing.T) { os.Chdir(t.TempDir()) // No internal/di/container.go exists, should warn - addFeatureToDI("Product", false) + addFeatureToDI("Product", "postgres", false) } func TestAddFeatureToDI_AlreadyExists(t *testing.T) { @@ -233,7 +233,7 @@ type Container struct { ` os.WriteFile(filepath.Join(diDir, "container.go"), []byte(content), 0o644) - addFeatureToDI("Product", false) + addFeatureToDI("Product", "postgres", false) // "already in the DI container" path } @@ -273,7 +273,7 @@ func (c *Container) setupHandlers() { os.WriteFile(filepath.Join(diDir, "container.go"), []byte(content), 0o644) sm := NewSafetyManager(true, false, false) - addFeatureToDI("Order", false, sm) + addFeatureToDI("Order", "postgres", false, sm) } func TestSetupMainGoWithFeature_Coverage(t *testing.T) { diff --git a/cmd/di.go b/cmd/di.go index 7b366b4..30e3f2e 100644 --- a/cmd/di.go +++ b/cmd/di.go @@ -25,6 +25,32 @@ func dbHandleType(database string) (goType, importPath string) { return "*gorm.DB", "gorm.io/gorm" } +// repoConstructorPrefix returns the prefix used by the repository generator for +// a given database, so that the DI container, feature wiring and integration +// tests all reference the constructor that was actually generated +// (NewRepository). postgres/mysql/sqlite share the GORM-based +// "Postgres" implementation; postgres-json and sqlserver have their own +// GORM-based constructors but still take a *gorm.DB, so they wire against the +// same container handle. +func repoConstructorPrefix(database string) string { + switch database { + case DBPostgresJSON: + return "PostgresJSON" + case DBSQLServer: + return "SQLServer" + case dbMongoDB: + return "Mongo" + case DBElasticsearch: + return "Elasticsearch" + case DBDynamoDB: + return "DynamoDB" + default: + // postgres, mysql, sqlite and any unknown SQL backend all use the + // shared GORM "Postgres" repository constructor. + return "Postgres" + } +} + var diCmd = &cobra.Command{ Use: "di", Short: "Generate dependency injection container", @@ -190,14 +216,10 @@ func generateSetupRepositories(content *strings.Builder, features []string, data for _, feature := range features { featureLower := strings.ToLower(feature) - var repoConstructor string - switch database { - case dbMongoDB: - repoConstructor = fmt.Sprintf("repository.NewMongo%sRepository(c.db)", feature) - default: - // All SQL databases use the shared GORM constructor. - repoConstructor = fmt.Sprintf("repository.NewPostgres%sRepository(c.db)", feature) - } + // Reference the constructor the repository generator actually emits + // for this database (NewRepository), so the container + // compiles for every backend, not just Postgres. + repoConstructor := fmt.Sprintf("repository.New%s%sRepository(c.db)", repoConstructorPrefix(database), feature) // Only wrap with the Redis cache decorator when one was actually // generated for this entity (goca repository --cache). Emitting @@ -349,13 +371,7 @@ func writeWireSets(content *strings.Builder, features []string, database string) func writeRepositorySet(content *strings.Builder, features []string, database string) { content.WriteString("\tRepositorySet = wire.NewSet(\n") for _, feature := range features { - switch database { - case dbMongoDB: - fmt.Fprintf(content, "\t\trepository.NewMongo%sRepository,\n", feature) - default: - // All SQL databases use the shared GORM constructor. - fmt.Fprintf(content, "\t\trepository.NewPostgres%sRepository,\n", feature) - } + fmt.Fprintf(content, "\t\trepository.New%s%sRepository,\n", repoConstructorPrefix(database), feature) } content.WriteString("\t)\n\n") } diff --git a/cmd/entity_test_generator.go b/cmd/entity_test_generator.go index 1c9bf88..3a3b8bc 100644 --- a/cmd/entity_test_generator.go +++ b/cmd/entity_test_generator.go @@ -24,9 +24,14 @@ func generateEntityTests(domainDir, entityName string, fields []Field, validatio // Package declaration content.WriteString("package domain\n\n") - // Imports - no time import since we don't test timestamp fields + // Imports. The generated tests use time.Now() for any time.Time field that + // is exercised (user-declared fields as well as the timestamp/soft-delete + // fields), so the "time" import is required whenever such a field exists. content.WriteString("import (\n") content.WriteString("\t\"testing\"\n") + if fieldsNeedTimeImport(fields) { + content.WriteString("\t\"time\"\n") + } content.WriteString("\n\t\"github.com/stretchr/testify/assert\"\n") content.WriteString(")\n\n") // Generate validation tests if validation is enabled if validation { @@ -184,6 +189,15 @@ func generateConstructorTests(content *strings.Builder, entityName string, field continue } + // time.Time fields are initialized with time.Now(); two separate + // time.Now() calls never compare equal, so assert the field is simply + // non-zero instead of equal to a fresh timestamp. + if field.Type == "time.Time" { + fmt.Fprintf(content, "\tassert.False(t, %s.%s.IsZero(), \"%s should be set correctly\")\n", + entityLower, field.Name, field.Name) + continue + } + expectedValue := getValidFieldValue(field) fmt.Fprintf(content, "\tassert.Equal(t, %s, %s.%s, \"%s should be set correctly\")\n", expectedValue, entityLower, field.Name, field.Name) @@ -308,6 +322,23 @@ func generateFieldTests(content *strings.Builder, entityName string, fields []Fi } } +// isTestSkippedField reports whether a field is excluded from the generated +// test value-setting loops (the framework-managed fields). +func isTestSkippedField(name string) bool { + return name == "ID" || name == "CreatedAt" || name == "UpdatedAt" || name == "DeletedAt" +} + +// fieldsNeedTimeImport reports whether any field exercised by the generated +// tests is a time.Time, which requires importing the "time" package. +func fieldsNeedTimeImport(fields []Field) bool { + for _, f := range fields { + if f.Type == "time.Time" && !isTestSkippedField(f.Name) { + return true + } + } + return false +} + // Helper functions to generate test values func getValidFieldValue(field Field) string { @@ -373,8 +404,10 @@ func compositeOrZeroLiteral(fieldType string) string { if v, ok := generateDefaultSampleValue(fieldType, 1); ok { return v } - // Unknown named type: an empty composite literal is valid for structs. - return fieldType + "{}" + // Unknown named scalar type: the entity generator emits these as + // `type string` stubs, so a string conversion is a valid zero literal + // (an empty composite literal would not compile for a string-based type). + return fieldType + "(\"\")" } func getInvalidDescription(field Field) string { diff --git a/cmd/feature.go b/cmd/feature.go index 90e1e19..6a2b88c 100644 --- a/cmd/feature.go +++ b/cmd/feature.go @@ -144,7 +144,7 @@ including domain, use cases, repository and handlers in a single operation.`, // 7. Auto-integrate with DI and main.go ui.Step(7, "Integrating automatically...") - autoIntegrateFeature(featureName, handlers, cacheFlag, safetyMgr) + autoIntegrateFeature(featureName, handlers, effectiveDatabase, cacheFlag, safetyMgr) // 8. Handle dependencies ui.Step(8, "Managing dependencies...") @@ -336,9 +336,9 @@ func printFeatureStructure(featureName, handlers string) { } // autoIntegrateFeature automatically integrates the feature with DI and main.go. -func autoIntegrateFeature(featureName, handlers string, cache bool, sm ...*SafetyManager) { +func autoIntegrateFeature(featureName, handlers, database string, cache bool, sm ...*SafetyManager) { ui.Dim(" Updating DI container...") - updateDIContainer(featureName, cache, sm...) + updateDIContainer(featureName, database, cache, sm...) ui.Dim(" Registering HTTP routes...") if strings.Contains(handlers, "http") { @@ -349,23 +349,23 @@ func autoIntegrateFeature(featureName, handlers string, cache bool, sm ...*Safet } // updateDIContainer updates or creates DI container with new feature. -func updateDIContainer(featureName string, cache bool, sm ...*SafetyManager) { +func updateDIContainer(featureName, database string, cache bool, sm ...*SafetyManager) { // Check if DI container exists diPath := filepath.Join("internal", "di", "container.go") if _, err := os.Stat(diPath); os.IsNotExist(err) { // DI doesn't exist, create it with this feature ui.Dim(fmt.Sprintf(" Creating DI container for %s...", featureName)) - generateDI(featureName, "postgres", false, false, sm...) + generateDI(featureName, database, false, false, sm...) } else { // DI exists, update it to include new feature ui.Dim(" Updating existing DI container...") - addFeatureToDI(featureName, cache, sm...) + addFeatureToDI(featureName, database, cache, sm...) } } // addFeatureToDI adds a new feature to existing DI container. -func addFeatureToDI(featureName string, cache bool, sm ...*SafetyManager) { +func addFeatureToDI(featureName, database string, cache bool, sm ...*SafetyManager) { diPath := filepath.Join("internal", "di", "container.go") content, err := os.ReadFile(diPath) @@ -386,7 +386,7 @@ func addFeatureToDI(featureName string, cache bool, sm ...*SafetyManager) { ui.Dim(fmt.Sprintf(" Adding %s to DI container...", featureName)) updatedContent := addFieldsToDIContainer(contentStr, featureName, featureLower) - updatedContent = addSetupMethodsToDI(updatedContent, featureName, featureLower, cache) + updatedContent = addSetupMethodsToDI(updatedContent, featureName, featureLower, database, cache) updatedContent = addGetterMethodsToDI(updatedContent, featureName, featureLower) // This is an in-place merge of an existing container.go that we just read, @@ -418,16 +418,24 @@ func addFieldsToDIContainer(content, featureName, featureLower string) string { } // addSetupMethodsToDI adds setup method calls for the feature. -func addSetupMethodsToDI(content, featureName, featureLower string, cache bool) string { +func addSetupMethodsToDI(content, featureName, featureLower, database string, cache bool) string { fieldName := strings.ToLower(featureName[:1]) + featureName[1:] // camelCase - // Add repository setup + // Add repository setup. Reference the constructor the repository generator + // actually emits for this database so the container compiles on every + // backend, not just Postgres. + repoPrefix := repoConstructorPrefix(database) + // Only wire the Redis cache decorator when (a) a decorator was actually + // generated for this entity and (b) the existing container exposes a + // redisClient field. Otherwise emitting NewCached…/c.redisClient would + // reference symbols the container does not provide and break compilation. + wireCache := cache && hasCacheDecorator(featureName) && strings.Contains(content, "redisClient") var repoSetup string - if cache { - repoSetup = fmt.Sprintf("\tbase%sRepo := repository.NewPostgres%sRepository(c.db)\n", featureName, featureName) + if wireCache { + repoSetup = fmt.Sprintf("\tbase%sRepo := repository.New%s%sRepository(c.db)\n", featureName, repoPrefix, featureName) repoSetup += fmt.Sprintf("\tc.%sRepo = repository.NewCached%sRepository(base%sRepo, c.redisClient, 5*time.Minute)\n", featureLower, featureName, featureName) } else { - repoSetup = fmt.Sprintf("\tc.%sRepo = repository.NewPostgres%sRepository(c.db)\n", featureLower, featureName) + repoSetup = fmt.Sprintf("\tc.%sRepo = repository.New%s%sRepository(c.db)\n", featureLower, repoPrefix, featureName) } setupRepoEnd := "}\n\nfunc (c *Container) setupUseCases() {" content = strings.Replace(content, setupRepoEnd, repoSetup+setupRepoEnd, 1) @@ -721,7 +729,16 @@ func ensureMainGoImport(content, importLine string) string { // ensureContainerScaffold injects (once) the DI container instantiation and an // /api/v1 subrouter together with the route marker into main.go. func ensureContainerScaffold(content string) string { - if strings.Contains(content, "container := di.NewContainer(db)") { + // MongoDB projects expose a *mongo.Client named mongoClient (and no `db` + // variable); the container's NewContainer takes a *mongo.Database, so the + // handle must be derived from the client. Every other backend exposes a + // *gorm.DB named db that the container accepts directly. + dbArg := "db" + if strings.Contains(content, "mongoClient") && !strings.Contains(content, "\tdb ") { + dbArg = "mongoClient.Database(cfg.Database.Name)" + } + + if strings.Contains(content, fmt.Sprintf("container := di.NewContainer(%s)", dbArg)) { return content } @@ -732,7 +749,7 @@ func ensureContainerScaffold(content string) string { scaffold := anchor + "\n" + "\t// Dependency injection container\n" + - "\tcontainer := di.NewContainer(db)\n" + + fmt.Sprintf("\tcontainer := di.NewContainer(%s)\n", dbArg) + "\t_ = container\n\n" + "\t// API v1 routes\n" + "\tapiRouter := router.PathPrefix(\"/api/v1\").Subrouter()\n" + diff --git a/cmd/feature_helpers_test.go b/cmd/feature_helpers_test.go index 3ce1eee..386f0b9 100644 --- a/cmd/feature_helpers_test.go +++ b/cmd/feature_helpers_test.go @@ -55,7 +55,7 @@ func (c *Container) setupHandlers() { // Getters` - result := addSetupMethodsToDI(content, "Product", "product", false) + result := addSetupMethodsToDI(content, "Product", "product", "postgres", false) assert.Contains(t, result, "c.productRepo = repository.NewPostgresProductRepository(c.db)") assert.Contains(t, result, "c.productUC = usecase.NewProductService(c.productRepo)") assert.Contains(t, result, "c.productHandler = http.NewProductHandler(c.productUC)") diff --git a/cmd/integrate.go b/cmd/integrate.go index 38288cc..7663045 100644 --- a/cmd/integrate.go +++ b/cmd/integrate.go @@ -211,16 +211,17 @@ func snakeToPascal(s string) string { func createOrUpdateDIContainer(features []string, sm ...*SafetyManager) { diPath := filepath.Join("internal", "di", "container.go") + configIntegration := NewConfigIntegration() + _ = configIntegration.LoadConfigForProject() + database := configIntegration.GetDatabaseType("") + if _, err := os.Stat(diPath); os.IsNotExist(err) { ui.Dim(" Creating DI container...") - configIntegration := NewConfigIntegration() - _ = configIntegration.LoadConfigForProject() - database := configIntegration.GetDatabaseType("") generateDI(strings.Join(features, ","), database, false, false, sm...) } else { ui.Dim(" Updating existing DI container...") for _, feature := range features { - addFeatureToDI(feature, false) + addFeatureToDI(feature, database, false) } } } diff --git a/cmd/repository.go b/cmd/repository.go index 6b2c7cd..264a33d 100644 --- a/cmd/repository.go +++ b/cmd/repository.go @@ -227,7 +227,12 @@ func generateRepositoryImplementation(dir, entity, database string, cache, trans case DBMongoDB: generateMongoRepository(dir, entity, cache, transactions, sm...) case DBSQLite: - generateSQLiteRepository(dir, entity, cache, transactions, sm...) + // SQLite projects run on GORM (gorm.io/driver/sqlite) just like the + // other SQL backends: the generated main.go/DI container hold a + // *gorm.DB, so the repository must be GORM-based (NewPostgres + // Repository(*gorm.DB)) to wire correctly. A bespoke database/sql + // implementation (*sql.DB) could never be injected by that container. + generatePostgresRepository(dir, entity, cache, transactions, sm...) case DBSQLServer: generateSQLServerRepository(dir, entity, cache, transactions, sm...) case DBElasticsearch: diff --git a/cmd/repository_fields.go b/cmd/repository_fields.go index 513aa57..04b6743 100644 --- a/cmd/repository_fields.go +++ b/cmd/repository_fields.go @@ -80,7 +80,9 @@ func generateRepositoryImplementationWithFields(dir, entity, database string, fi case DBSQLServer: generateSQLServerRepositoryWithFields(dir, entity, fields, cache, transactions, sm...) case DBSQLite: - generateSQLiteRepositoryWithFields(dir, entity, fields, cache, transactions, sm...) + // SQLite uses GORM (gorm.io/driver/sqlite) so the generated *gorm.DB + // container can inject it; see generateRepositoryImplementation. + generatePostgresRepositoryWithFields(dir, entity, fields, cache, transactions, sm...) case DBElasticsearch: generateElasticsearchRepositoryWithFields(dir, entity, fields, cache, transactions, sm...) case DBDynamoDB: @@ -105,11 +107,6 @@ func generateSQLServerRepositoryWithFields(dir, entity string, fields []Field, c appendGormFinders(dir, "sqlserver_"+strings.ToLower(entity)+"_repository.go", fmt.Sprintf("sqlserver%sRepository", entity), entity, fields, sm...) } -func generateSQLiteRepositoryWithFields(dir, entity string, fields []Field, cache, transactions bool, sm ...*SafetyManager) { - generateSQLiteRepository(dir, entity, cache, transactions, sm...) - appendSQLiteFinders(dir, entity, fields, sm...) -} - func generateElasticsearchRepositoryWithFields(dir, entity string, fields []Field, cache, transactions bool, sm ...*SafetyManager) { generateElasticsearchRepository(dir, entity, cache, transactions, sm...) appendDelegatingFinders(dir, "elasticsearch_"+strings.ToLower(entity)+"_repository.go", fmt.Sprintf("elasticsearch%sRepository", entity), "e", entity, fields, sm...) @@ -134,32 +131,6 @@ func appendGormFinders(dir, file, repoName, entity string, fields []Field, sm .. appendToRepoFile(filepath.Join(dir, file), b.String(), nil, sm...) } -// appendSQLiteFinders appends raw-SQL per-field finders for the SQLite repo. -func appendSQLiteFinders(dir, entity string, fields []Field, sm ...*SafetyManager) { - methods := generateSearchMethods(fields, entity) - if len(methods) == 0 { - return - } - entityLower := strings.ToLower(entity) - repoName := fmt.Sprintf("sqlite%sRepository", entity) - var b strings.Builder - for _, m := range methods { - paramName := strings.ToLower(m.FieldName) - fmt.Fprintf(&b, "func (s *%s) %s(%s %s) %s {\n", repoName, m.MethodName, paramName, m.FieldType, m.ReturnType) - b.WriteString("\tvar data []byte\n") - fmt.Fprintf(&b, "\tquery := \"SELECT data FROM %ss WHERE json_extract(data, '$.%s') = ? LIMIT 1\"\n", entityLower, m.FieldName) - fmt.Fprintf(&b, "\tif err := s.db.QueryRow(query, %s).Scan(&data); err != nil {\n", paramName) - fmt.Fprintf(&b, "\t\tif err == sql.ErrNoRows {\n\t\t\treturn nil, fmt.Errorf(\"%s not found\")\n\t\t}\n", entity) - b.WriteString("\t\treturn nil, fmt.Errorf(\"failed to query: %w\", err)\n\t}\n") - fmt.Fprintf(&b, "\tvar %s domain.%s\n", entityLower, entity) - fmt.Fprintf(&b, "\tif err := json.Unmarshal(data, &%s); err != nil {\n", entityLower) - b.WriteString("\t\treturn nil, fmt.Errorf(\"failed to unmarshal: %w\", err)\n\t}\n") - fmt.Fprintf(&b, "\treturn &%s, nil\n", entityLower) - b.WriteString("}\n\n") - } - appendToRepoFile(filepath.Join(dir, "sqlite_"+entityLower+"_repository.go"), b.String(), nil, sm...) -} - // appendDelegatingFinders appends per-field finders that reuse FindAll and filter // in memory — used for backends (Elasticsearch, DynamoDB) where a dedicated query // per field is out of scope but the interface still requires the method. diff --git a/cmd/repository_other_db.go b/cmd/repository_other_db.go index 7f6c7ea..4c7e0db 100644 --- a/cmd/repository_other_db.go +++ b/cmd/repository_other_db.go @@ -95,7 +95,7 @@ func generateSQLServerRepository(dir, entity string, cache, transactions bool, s // Save method content.WriteString(fmt.Sprintf("func (s *%s) Save(%s *domain.%s) error {\n", repoName, entityLower, entity)) content.WriteString(fmt.Sprintf("\tif err := s.db.Create(%s).Error; err != nil {\n", entityLower)) - content.WriteString("\t\treturn fmt.Errorf(\"failed to save %s: %%w\", err)\n") + content.WriteString(fmt.Sprintf("\t\treturn fmt.Errorf(\"failed to save %s: %%w\", err)\n", entityLower)) content.WriteString("\t}\n") content.WriteString("\treturn nil\n") content.WriteString("}\n\n") @@ -105,7 +105,7 @@ func generateSQLServerRepository(dir, entity string, cache, transactions bool, s content.WriteString(fmt.Sprintf("\tvar %s domain.%s\n", entityLower, entity)) content.WriteString(fmt.Sprintf("\tif err := s.db.WithContext(s.db.Statement.Context).First(&%s, id).Error; err != nil {\n", entityLower)) content.WriteString("\t\tif err == gorm.ErrRecordNotFound {\n") - content.WriteString("\t\t\treturn nil, fmt.Errorf(\"%s not found\")\n") + content.WriteString(fmt.Sprintf("\t\t\treturn nil, fmt.Errorf(\"%s not found\")\n", entity)) content.WriteString("\t\t}\n") content.WriteString("\t\treturn nil, err\n") content.WriteString("\t}\n") @@ -115,7 +115,7 @@ func generateSQLServerRepository(dir, entity string, cache, transactions bool, s // Update method content.WriteString(fmt.Sprintf("func (s *%s) Update(%s *domain.%s) error {\n", repoName, entityLower, entity)) content.WriteString(fmt.Sprintf("\tif err := s.db.Save(%s).Error; err != nil {\n", entityLower)) - content.WriteString("\t\treturn fmt.Errorf(\"failed to update %s: %%w\", err)\n") + content.WriteString(fmt.Sprintf("\t\treturn fmt.Errorf(\"failed to update %s: %%w\", err)\n", entityLower)) content.WriteString("\t}\n") content.WriteString("\treturn nil\n") content.WriteString("}\n\n") @@ -123,7 +123,7 @@ func generateSQLServerRepository(dir, entity string, cache, transactions bool, s // Delete method content.WriteString(fmt.Sprintf("func (s *%s) Delete(id int) error {\n", repoName)) content.WriteString(fmt.Sprintf("\tif err := s.db.Delete(&domain.%s{}, id).Error; err != nil {\n", entity)) - content.WriteString("\t\treturn fmt.Errorf(\"failed to delete %s: %%w\", err)\n") + content.WriteString(fmt.Sprintf("\t\treturn fmt.Errorf(\"failed to delete %s: %%w\", err)\n", entityLower)) content.WriteString("\t}\n") content.WriteString("\treturn nil\n") content.WriteString("}\n\n") @@ -132,7 +132,7 @@ func generateSQLServerRepository(dir, entity string, cache, transactions bool, s content.WriteString(fmt.Sprintf("func (s *%s) FindAll() ([]domain.%s, error) {\n", repoName, entity)) content.WriteString(fmt.Sprintf("\tvar %ss []domain.%s\n", entityLower, entity)) content.WriteString(fmt.Sprintf("\tif err := s.db.Find(&%ss).Error; err != nil {\n", entityLower)) - content.WriteString("\t\treturn nil, fmt.Errorf(\"failed to fetch %ss: %%w\", err)\n") + content.WriteString(fmt.Sprintf("\t\treturn nil, fmt.Errorf(\"failed to fetch %ss: %%w\", err)\n", entityLower)) content.WriteString("\t}\n") content.WriteString(fmt.Sprintf("\treturn %ss, nil\n", entityLower)) content.WriteString("}\n") @@ -292,7 +292,7 @@ func generateDynamoDBRepository(dir, entity string, cache, transactions bool, sm // Save method content.WriteString(fmt.Sprintf("func (d *%s) Save(%s *domain.%s) error {\n", repoName, entityLower, entity)) content.WriteString(fmt.Sprintf("\tav, err := attributevalue.MarshalMap(%s)\n", entityLower)) - content.WriteString("\tif err != nil {\n\t\treturn fmt.Errorf(\"failed to marshal: %%w\", err)\n\t}\n") + content.WriteString("\tif err != nil {\n\t\treturn fmt.Errorf(\"failed to marshal: %w\", err)\n\t}\n") content.WriteString("\t_, err = d.client.PutItem(context.Background(), &dynamodb.PutItemInput{\n") content.WriteString("\t\tTableName: &d.tableName,\n") content.WriteString("\t\tItem: av,\n") @@ -308,10 +308,10 @@ func generateDynamoDBRepository(dir, entity string, cache, transactions bool, sm content.WriteString("\t\t\t\"id\": &types.AttributeValueMemberN{Value: strconv.Itoa(id)},\n") content.WriteString("\t\t},\n") content.WriteString("\t})\n") - content.WriteString("\tif err != nil {\n\t\treturn nil, fmt.Errorf(\"failed to get item: %%w\", err)\n\t}\n") + content.WriteString("\tif err != nil {\n\t\treturn nil, fmt.Errorf(\"failed to get item: %w\", err)\n\t}\n") content.WriteString(fmt.Sprintf("\tvar %s domain.%s\n", entityLower, entity)) content.WriteString(fmt.Sprintf("\terr = attributevalue.UnmarshalMap(result.Item, &%s)\n", entityLower)) - content.WriteString("\tif err != nil {\n\t\treturn nil, fmt.Errorf(\"failed to unmarshal: %%w\", err)\n\t}\n") + content.WriteString("\tif err != nil {\n\t\treturn nil, fmt.Errorf(\"failed to unmarshal: %w\", err)\n\t}\n") content.WriteString(fmt.Sprintf("\treturn &%s, nil\n", entityLower)) content.WriteString("}\n\n") @@ -336,10 +336,10 @@ func generateDynamoDBRepository(dir, entity string, cache, transactions bool, sm content.WriteString("\tresult, err := d.client.Scan(context.Background(), &dynamodb.ScanInput{\n") content.WriteString("\t\tTableName: &d.tableName,\n") content.WriteString("\t})\n") - content.WriteString("\tif err != nil {\n\t\treturn nil, fmt.Errorf(\"failed to scan: %%w\", err)\n\t}\n") + content.WriteString("\tif err != nil {\n\t\treturn nil, fmt.Errorf(\"failed to scan: %w\", err)\n\t}\n") content.WriteString(fmt.Sprintf("\tvar %ss []domain.%s\n", entityLower, entity)) content.WriteString(fmt.Sprintf("\terr = attributevalue.UnmarshalListOfMaps(result.Items, &%ss)\n", entityLower)) - content.WriteString("\tif err != nil {\n\t\treturn nil, fmt.Errorf(\"failed to unmarshal: %%w\", err)\n\t}\n") + content.WriteString("\tif err != nil {\n\t\treturn nil, fmt.Errorf(\"failed to unmarshal: %w\", err)\n\t}\n") content.WriteString(fmt.Sprintf("\treturn %ss, nil\n", entityLower)) content.WriteString("}\n") @@ -373,10 +373,10 @@ func generateSQLiteRepository(dir, entity string, cache, transactions bool, sm . content.WriteString(fmt.Sprintf("func (s *%s) Save(%s *domain.%s) error {\n", repoName, entityLower, entity)) content.WriteString("\tvar data []byte\n") content.WriteString(fmt.Sprintf("\tdata, err := json.Marshal(%s)\n", entityLower)) - content.WriteString("\tif err != nil {\n\t\treturn fmt.Errorf(\"failed to marshal: %%w\", err)\n\t}\n") + content.WriteString("\tif err != nil {\n\t\treturn fmt.Errorf(\"failed to marshal: %w\", err)\n\t}\n") content.WriteString(fmt.Sprintf("\tquery := \"INSERT INTO %ss (data) VALUES (?)\"\n", entityLower)) content.WriteString("\tif _, err := s.db.Exec(query, data); err != nil {\n") - content.WriteString("\t\treturn fmt.Errorf(\"failed to insert: %%w\", err)\n\t}\n") + content.WriteString("\t\treturn fmt.Errorf(\"failed to insert: %w\", err)\n\t}\n") content.WriteString("\treturn nil\n") content.WriteString("}\n\n") @@ -386,20 +386,20 @@ func generateSQLiteRepository(dir, entity string, cache, transactions bool, sm . content.WriteString(fmt.Sprintf("\tquery := \"SELECT data FROM %ss WHERE id = ? LIMIT 1\"\n", entityLower)) content.WriteString("\tif err := s.db.QueryRow(query, id).Scan(&data); err != nil {\n") content.WriteString(fmt.Sprintf("\t\tif err == sql.ErrNoRows {\n\t\t\treturn nil, fmt.Errorf(\"%s not found\")\n\t\t}\n", entity)) - content.WriteString("\t\treturn nil, fmt.Errorf(\"failed to query: %%w\", err)\n\t}\n") + content.WriteString("\t\treturn nil, fmt.Errorf(\"failed to query: %w\", err)\n\t}\n") content.WriteString(fmt.Sprintf("\tvar %s domain.%s\n", entityLower, entity)) content.WriteString(fmt.Sprintf("\tif err := json.Unmarshal(data, &%s); err != nil {\n", entityLower)) - content.WriteString("\t\treturn nil, fmt.Errorf(\"failed to unmarshal: %%w\", err)\n\t}\n") + content.WriteString("\t\treturn nil, fmt.Errorf(\"failed to unmarshal: %w\", err)\n\t}\n") content.WriteString(fmt.Sprintf("\treturn &%s, nil\n", entityLower)) content.WriteString("}\n\n") // Update method content.WriteString(fmt.Sprintf("func (s *%s) Update(%s *domain.%s) error {\n", repoName, entityLower, entity)) content.WriteString(fmt.Sprintf("\tdata, err := json.Marshal(%s)\n", entityLower)) - content.WriteString("\tif err != nil {\n\t\treturn fmt.Errorf(\"failed to marshal: %%w\", err)\n\t}\n") + content.WriteString("\tif err != nil {\n\t\treturn fmt.Errorf(\"failed to marshal: %w\", err)\n\t}\n") content.WriteString(fmt.Sprintf("\tquery := \"UPDATE %ss SET data = ? WHERE id = ?\"\n", entityLower)) content.WriteString(fmt.Sprintf("\tif _, err := s.db.Exec(query, data, %s.ID); err != nil {\n", entityLower)) - content.WriteString("\t\treturn fmt.Errorf(\"failed to update: %%w\", err)\n\t}\n") + content.WriteString("\t\treturn fmt.Errorf(\"failed to update: %w\", err)\n\t}\n") content.WriteString("\treturn nil\n") content.WriteString("}\n\n") @@ -407,7 +407,7 @@ func generateSQLiteRepository(dir, entity string, cache, transactions bool, sm . content.WriteString(fmt.Sprintf("func (s *%s) Delete(id int) error {\n", repoName)) content.WriteString(fmt.Sprintf("\tquery := \"DELETE FROM %ss WHERE id = ?\"\n", entityLower)) content.WriteString("\tif _, err := s.db.Exec(query, id); err != nil {\n") - content.WriteString("\t\treturn fmt.Errorf(\"failed to delete: %%w\", err)\n\t}\n") + content.WriteString("\t\treturn fmt.Errorf(\"failed to delete: %w\", err)\n\t}\n") content.WriteString("\treturn nil\n") content.WriteString("}\n\n") @@ -415,17 +415,17 @@ func generateSQLiteRepository(dir, entity string, cache, transactions bool, sm . content.WriteString(fmt.Sprintf("func (s *%s) FindAll() ([]domain.%s, error) {\n", repoName, entity)) content.WriteString(fmt.Sprintf("\tquery := \"SELECT data FROM %ss\"\n", entityLower)) content.WriteString("\trows, err := s.db.Query(query)\n") - content.WriteString("\tif err != nil {\n\t\treturn nil, fmt.Errorf(\"failed to query: %%w\", err)\n\t}\n") + content.WriteString("\tif err != nil {\n\t\treturn nil, fmt.Errorf(\"failed to query: %w\", err)\n\t}\n") content.WriteString("\tdefer rows.Close()\n") content.WriteString(fmt.Sprintf("\tvar %ss []domain.%s\n", entityLower, entity)) content.WriteString("\tfor rows.Next() {\n") content.WriteString("\t\tvar data []byte\n") - content.WriteString("\t\tif err := rows.Scan(&data); err != nil {\n\t\t\treturn nil, fmt.Errorf(\"failed to scan: %%w\", err)\n\t\t}\n") + content.WriteString("\t\tif err := rows.Scan(&data); err != nil {\n\t\t\treturn nil, fmt.Errorf(\"failed to scan: %w\", err)\n\t\t}\n") content.WriteString(fmt.Sprintf("\t\tvar %s domain.%s\n", entityLower, entity)) content.WriteString(fmt.Sprintf("\t\tif err := json.Unmarshal(data, &%s); err != nil {\n\t\t\treturn nil, fmt.Errorf(\"failed to unmarshal: %%w\", err)\n\t\t}\n", entityLower)) content.WriteString(fmt.Sprintf("\t\t%ss = append(%ss, %s)\n", entityLower, entityLower, entityLower)) content.WriteString("\t}\n") - content.WriteString("\tif err := rows.Err(); err != nil {\n\t\treturn nil, fmt.Errorf(\"rows error: %%w\", err)\n\t}\n") + content.WriteString("\tif err := rows.Err(); err != nil {\n\t\treturn nil, fmt.Errorf(\"rows error: %w\", err)\n\t}\n") content.WriteString(fmt.Sprintf("\treturn %ss, nil\n", entityLower)) content.WriteString("}\n") diff --git a/cmd/test_integration.go b/cmd/test_integration.go index d7eb7ae..3856ffa 100644 --- a/cmd/test_integration.go +++ b/cmd/test_integration.go @@ -181,7 +181,7 @@ func Test%[1]sIntegration(t *testing.T) { defer cleanupTestDatabase(t, db) // Initialize dependencies - repo := repository.NewPostgres%[1]sRepository(db) + repo := repository.New%[4]s%[1]sRepository(db) service := usecase.New%[1]sService(repo) @@ -274,7 +274,7 @@ func Test%[1]sRepositoryIntegration(t *testing.T) { db := setupTestDatabase(t, "%[2]s") defer cleanupTestDatabase(t, db) - repo := repository.NewPostgres%[1]sRepository(db) + repo := repository.New%[4]s%[1]sRepository(db) t.Run("SaveAndFindByID", func(t *testing.T) { %[3]s := &domain.%[1]s{ @@ -324,7 +324,7 @@ func Test%[1]sRepositoryIntegration(t *testing.T) { assert.Error(t, err) }) } -`, entityName, database, lowerEntity) +`, entityName, database, lowerEntity, repoConstructorPrefix(database)) return replaceIntegrationTestTODOs(content, fields, entityName) } diff --git a/cmd/usecase.go b/cmd/usecase.go index 191c090..b348414 100644 --- a/cmd/usecase.go +++ b/cmd/usecase.go @@ -201,6 +201,7 @@ func generateDTOFileWithFields(dir, entity string, operations []string, validati body := bodyB.String() usesErrors := strings.Contains(body, "errors.") usesStrings := strings.Contains(body, "strings.") + usesTime := strings.Contains(body, "time.") var content strings.Builder @@ -217,13 +218,16 @@ func generateDTOFileWithFields(dir, entity string, operations []string, validati return } - // Ensure errors/strings are imported only if the new DTOs use them. + // Ensure errors/strings/time are imported only if the new DTOs use them. if usesErrors { existingStr = ensureImportInDTOFile(existingStr, "errors", moduleName) } if usesStrings { existingStr = ensureImportInDTOFile(existingStr, "strings", moduleName) } + if usesTime { + existingStr = ensureImportInDTOFile(existingStr, "time", moduleName) + } // Add the existing content without the final newline content.WriteString(strings.TrimSuffix(existingStr, "\n")) @@ -239,7 +243,10 @@ func generateDTOFileWithFields(dir, entity string, operations []string, validati if usesStrings { content.WriteString("\t\"strings\"\n") } - if usesErrors || usesStrings { + if usesTime { + content.WriteString("\t\"time\"\n") + } + if usesErrors || usesStrings || usesTime { content.WriteString("\n") } content.WriteString(fmt.Sprintf("\t\"%s/internal/domain\"\n", getImportPath(moduleName)))