Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 47 additions & 4 deletions cmd/cache_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,23 @@ func TestGenerateManualDI_WithoutCache(t *testing.T) {
}

func TestAddSetupMethodsToDI_WithCache(t *testing.T) {
t.Parallel()
// Not parallel: relies on os.Chdir so hasCacheDecorator can find the file.
origDir, _ := os.Getwd()
defer os.Chdir(origDir)
dir := t.TempDir()
require.NoError(t, os.Chdir(dir))

content := `func (c *Container) setupRepositories() {
// Cache wiring is only emitted when a decorator file exists AND the
// container exposes a redisClient field; set up both preconditions.
repoDir := filepath.Join(DirInternal, DirRepository)
require.NoError(t, os.MkdirAll(repoDir, 0o755))
require.NoError(t, os.WriteFile(filepath.Join(repoDir, "cached_product_repository.go"), []byte("package repository\n"), 0o644))

content := `type Container struct {
redisClient *redis.Client
}

func (c *Container) setupRepositories() {
}

func (c *Container) setupUseCases() {
Expand All @@ -216,11 +230,40 @@ func (c *Container) setupHandlers() {

// Getters`

result := addSetupMethodsToDI(content, "Product", "product", true)
result := addSetupMethodsToDI(content, "Product", "product", "postgres", true)
assert.Contains(t, result, "baseProductRepo := repository.NewPostgresProductRepository(c.db)")
assert.Contains(t, result, "c.productRepo = repository.NewCachedProductRepository(baseProductRepo, c.redisClient, 5*time.Minute)")
}

// TestAddSetupMethodsToDI_CacheWithoutRedisClient verifies that when the
// container has no redisClient field, the cache decorator is NOT wired (which
// would otherwise reference an undefined field) and the bare repo is used.
func TestAddSetupMethodsToDI_CacheWithoutRedisClient(t *testing.T) {
origDir, _ := os.Getwd()
defer os.Chdir(origDir)
dir := t.TempDir()
require.NoError(t, os.Chdir(dir))

repoDir := filepath.Join(DirInternal, DirRepository)
require.NoError(t, os.MkdirAll(repoDir, 0o755))
require.NoError(t, os.WriteFile(filepath.Join(repoDir, "cached_product_repository.go"), []byte("package repository\n"), 0o644))

content := `func (c *Container) setupRepositories() {
}

func (c *Container) setupUseCases() {
}

func (c *Container) setupHandlers() {
}

// Getters`

result := addSetupMethodsToDI(content, "Product", "product", "postgres", true)
assert.Contains(t, result, "c.productRepo = repository.NewPostgresProductRepository(c.db)")
assert.NotContains(t, result, "NewCachedProductRepository")
}

func TestAddSetupMethodsToDI_WithoutCache(t *testing.T) {
t.Parallel()

Expand All @@ -235,7 +278,7 @@ func (c *Container) setupHandlers() {

// Getters`

result := addSetupMethodsToDI(content, "Order", "order", false)
result := addSetupMethodsToDI(content, "Order", "order", "postgres", false)
assert.Contains(t, result, "c.orderRepo = repository.NewPostgresOrderRepository(c.db)")
assert.NotContains(t, result, "baseOrderRepo")
assert.NotContains(t, result, "NewCachedOrderRepository")
Expand Down
8 changes: 4 additions & 4 deletions cmd/ci_templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ func generateTestWorkflow(data CITemplateData) string {
--health-timeout 5s
--health-retries 5`
envBlock = `
env:
DATABASE_URL: postgres://postgres:postgres@localhost:5432/testdb?sslmode=disable`
env:
DATABASE_URL: postgres://postgres:postgres@localhost:5432/testdb?sslmode=disable`
} else if data.Database == "mysql" {
svc = `
services:
Expand All @@ -51,8 +51,8 @@ func generateTestWorkflow(data CITemplateData) string {
--health-timeout 5s
--health-retries 5`
envBlock = `
env:
DATABASE_URL: root:root@tcp(127.0.0.1:3306)/testdb`
env:
DATABASE_URL: root:root@tcp(127.0.0.1:3306)/testdb`
}

return fmt.Sprintf(`name: Test
Expand Down
4 changes: 2 additions & 2 deletions cmd/coverage_batch5_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ func TestUpdateDIContainer_NoExistingFile(t *testing.T) {
os.Chdir(t.TempDir())

sm := NewSafetyManager(true, false, false)
updateDIContainer("Product", false, sm)
updateDIContainer("Product", "postgres", false, sm)
// Should handle gracefully when no di/container.go exists
}

Expand All @@ -228,7 +228,7 @@ func TestAutoIntegrateFeature_NoMainGo(t *testing.T) {
os.Chdir(t.TempDir())

sm := NewSafetyManager(true, false, false)
autoIntegrateFeature("Product", "http", false, sm)
autoIntegrateFeature("Product", "http", "postgres", false, sm)
// Should handle gracefully when main.go doesnt exist
}

Expand Down
6 changes: 3 additions & 3 deletions cmd/coverage_batch6_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ func TestAddFeatureToDI_NoDIFile(t *testing.T) {
os.Chdir(t.TempDir())

// No internal/di/container.go exists, should warn
addFeatureToDI("Product", false)
addFeatureToDI("Product", "postgres", false)
}

func TestAddFeatureToDI_AlreadyExists(t *testing.T) {
Expand All @@ -233,7 +233,7 @@ type Container struct {
`
os.WriteFile(filepath.Join(diDir, "container.go"), []byte(content), 0o644)

addFeatureToDI("Product", false)
addFeatureToDI("Product", "postgres", false)
// "already in the DI container" path
}

Expand Down Expand Up @@ -273,7 +273,7 @@ func (c *Container) setupHandlers() {
os.WriteFile(filepath.Join(diDir, "container.go"), []byte(content), 0o644)

sm := NewSafetyManager(true, false, false)
addFeatureToDI("Order", false, sm)
addFeatureToDI("Order", "postgres", false, sm)
}

func TestSetupMainGoWithFeature_Coverage(t *testing.T) {
Expand Down
46 changes: 31 additions & 15 deletions cmd/di.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,32 @@ func dbHandleType(database string) (goType, importPath string) {
return "*gorm.DB", "gorm.io/gorm"
}

// repoConstructorPrefix returns the prefix used by the repository generator for
// a given database, so that the DI container, feature wiring and integration
// tests all reference the constructor that was actually generated
// (New<prefix><Entity>Repository). postgres/mysql/sqlite share the GORM-based
// "Postgres" implementation; postgres-json and sqlserver have their own
// GORM-based constructors but still take a *gorm.DB, so they wire against the
// same container handle.
func repoConstructorPrefix(database string) string {
switch database {
case DBPostgresJSON:
return "PostgresJSON"
case DBSQLServer:
return "SQLServer"
case dbMongoDB:
return "Mongo"
case DBElasticsearch:
return "Elasticsearch"
case DBDynamoDB:
return "DynamoDB"
default:
// postgres, mysql, sqlite and any unknown SQL backend all use the
// shared GORM "Postgres" repository constructor.
return "Postgres"
}
}

var diCmd = &cobra.Command{
Use: "di",
Short: "Generate dependency injection container",
Expand Down Expand Up @@ -190,14 +216,10 @@ func generateSetupRepositories(content *strings.Builder, features []string, data

for _, feature := range features {
featureLower := strings.ToLower(feature)
var repoConstructor string
switch database {
case dbMongoDB:
repoConstructor = fmt.Sprintf("repository.NewMongo%sRepository(c.db)", feature)
default:
// All SQL databases use the shared GORM constructor.
repoConstructor = fmt.Sprintf("repository.NewPostgres%sRepository(c.db)", feature)
}
// Reference the constructor the repository generator actually emits
// for this database (New<prefix><Entity>Repository), so the container
// compiles for every backend, not just Postgres.
repoConstructor := fmt.Sprintf("repository.New%s%sRepository(c.db)", repoConstructorPrefix(database), feature)

// Only wrap with the Redis cache decorator when one was actually
// generated for this entity (goca repository --cache). Emitting
Expand Down Expand Up @@ -349,13 +371,7 @@ func writeWireSets(content *strings.Builder, features []string, database string)
func writeRepositorySet(content *strings.Builder, features []string, database string) {
content.WriteString("\tRepositorySet = wire.NewSet(\n")
for _, feature := range features {
switch database {
case dbMongoDB:
fmt.Fprintf(content, "\t\trepository.NewMongo%sRepository,\n", feature)
default:
// All SQL databases use the shared GORM constructor.
fmt.Fprintf(content, "\t\trepository.NewPostgres%sRepository,\n", feature)
}
fmt.Fprintf(content, "\t\trepository.New%s%sRepository,\n", repoConstructorPrefix(database), feature)
}
content.WriteString("\t)\n\n")
}
Expand Down
39 changes: 36 additions & 3 deletions cmd/entity_test_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,14 @@ func generateEntityTests(domainDir, entityName string, fields []Field, validatio
// Package declaration
content.WriteString("package domain\n\n")

// Imports - no time import since we don't test timestamp fields
// Imports. The generated tests use time.Now() for any time.Time field that
// is exercised (user-declared fields as well as the timestamp/soft-delete
// fields), so the "time" import is required whenever such a field exists.
content.WriteString("import (\n")
content.WriteString("\t\"testing\"\n")
if fieldsNeedTimeImport(fields) {
content.WriteString("\t\"time\"\n")
}
content.WriteString("\n\t\"github.com/stretchr/testify/assert\"\n")
content.WriteString(")\n\n") // Generate validation tests if validation is enabled
if validation {
Expand Down Expand Up @@ -184,6 +189,15 @@ func generateConstructorTests(content *strings.Builder, entityName string, field
continue
}

// time.Time fields are initialized with time.Now(); two separate
// time.Now() calls never compare equal, so assert the field is simply
// non-zero instead of equal to a fresh timestamp.
if field.Type == "time.Time" {
fmt.Fprintf(content, "\tassert.False(t, %s.%s.IsZero(), \"%s should be set correctly\")\n",
entityLower, field.Name, field.Name)
continue
}

expectedValue := getValidFieldValue(field)
fmt.Fprintf(content, "\tassert.Equal(t, %s, %s.%s, \"%s should be set correctly\")\n",
expectedValue, entityLower, field.Name, field.Name)
Expand Down Expand Up @@ -308,6 +322,23 @@ func generateFieldTests(content *strings.Builder, entityName string, fields []Fi
}
}

// isTestSkippedField reports whether a field is excluded from the generated
// test value-setting loops (the framework-managed fields).
func isTestSkippedField(name string) bool {
return name == "ID" || name == "CreatedAt" || name == "UpdatedAt" || name == "DeletedAt"
}

// fieldsNeedTimeImport reports whether any field exercised by the generated
// tests is a time.Time, which requires importing the "time" package.
func fieldsNeedTimeImport(fields []Field) bool {
for _, f := range fields {
if f.Type == "time.Time" && !isTestSkippedField(f.Name) {
return true
}
}
return false
}

// Helper functions to generate test values

func getValidFieldValue(field Field) string {
Expand Down Expand Up @@ -373,8 +404,10 @@ func compositeOrZeroLiteral(fieldType string) string {
if v, ok := generateDefaultSampleValue(fieldType, 1); ok {
return v
}
// Unknown named type: an empty composite literal is valid for structs.
return fieldType + "{}"
// Unknown named scalar type: the entity generator emits these as
// `type <T> string` stubs, so a string conversion is a valid zero literal
// (an empty composite literal would not compile for a string-based type).
return fieldType + "(\"\")"
}

func getInvalidDescription(field Field) string {
Expand Down
47 changes: 32 additions & 15 deletions cmd/feature.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ including domain, use cases, repository and handlers in a single operation.`,

// 7. Auto-integrate with DI and main.go
ui.Step(7, "Integrating automatically...")
autoIntegrateFeature(featureName, handlers, cacheFlag, safetyMgr)
autoIntegrateFeature(featureName, handlers, effectiveDatabase, cacheFlag, safetyMgr)

// 8. Handle dependencies
ui.Step(8, "Managing dependencies...")
Expand Down Expand Up @@ -336,9 +336,9 @@ func printFeatureStructure(featureName, handlers string) {
}

// autoIntegrateFeature automatically integrates the feature with DI and main.go.
func autoIntegrateFeature(featureName, handlers string, cache bool, sm ...*SafetyManager) {
func autoIntegrateFeature(featureName, handlers, database string, cache bool, sm ...*SafetyManager) {
ui.Dim(" Updating DI container...")
updateDIContainer(featureName, cache, sm...)
updateDIContainer(featureName, database, cache, sm...)

ui.Dim(" Registering HTTP routes...")
if strings.Contains(handlers, "http") {
Expand All @@ -349,23 +349,23 @@ func autoIntegrateFeature(featureName, handlers string, cache bool, sm ...*Safet
}

// updateDIContainer updates or creates DI container with new feature.
func updateDIContainer(featureName string, cache bool, sm ...*SafetyManager) {
func updateDIContainer(featureName, database string, cache bool, sm ...*SafetyManager) {
// Check if DI container exists
diPath := filepath.Join("internal", "di", "container.go")

if _, err := os.Stat(diPath); os.IsNotExist(err) {
// DI doesn't exist, create it with this feature
ui.Dim(fmt.Sprintf(" Creating DI container for %s...", featureName))
generateDI(featureName, "postgres", false, false, sm...)
generateDI(featureName, database, false, false, sm...)
} else {
// DI exists, update it to include new feature
ui.Dim(" Updating existing DI container...")
addFeatureToDI(featureName, cache, sm...)
addFeatureToDI(featureName, database, cache, sm...)
}
}

// addFeatureToDI adds a new feature to existing DI container.
func addFeatureToDI(featureName string, cache bool, sm ...*SafetyManager) {
func addFeatureToDI(featureName, database string, cache bool, sm ...*SafetyManager) {
diPath := filepath.Join("internal", "di", "container.go")

content, err := os.ReadFile(diPath)
Expand All @@ -386,7 +386,7 @@ func addFeatureToDI(featureName string, cache bool, sm ...*SafetyManager) {
ui.Dim(fmt.Sprintf(" Adding %s to DI container...", featureName))

updatedContent := addFieldsToDIContainer(contentStr, featureName, featureLower)
updatedContent = addSetupMethodsToDI(updatedContent, featureName, featureLower, cache)
updatedContent = addSetupMethodsToDI(updatedContent, featureName, featureLower, database, cache)
updatedContent = addGetterMethodsToDI(updatedContent, featureName, featureLower)

// This is an in-place merge of an existing container.go that we just read,
Expand Down Expand Up @@ -418,16 +418,24 @@ func addFieldsToDIContainer(content, featureName, featureLower string) string {
}

// addSetupMethodsToDI adds setup method calls for the feature.
func addSetupMethodsToDI(content, featureName, featureLower string, cache bool) string {
func addSetupMethodsToDI(content, featureName, featureLower, database string, cache bool) string {
fieldName := strings.ToLower(featureName[:1]) + featureName[1:] // camelCase

// Add repository setup
// Add repository setup. Reference the constructor the repository generator
// actually emits for this database so the container compiles on every
// backend, not just Postgres.
repoPrefix := repoConstructorPrefix(database)
// Only wire the Redis cache decorator when (a) a decorator was actually
// generated for this entity and (b) the existing container exposes a
// redisClient field. Otherwise emitting NewCached…/c.redisClient would
// reference symbols the container does not provide and break compilation.
wireCache := cache && hasCacheDecorator(featureName) && strings.Contains(content, "redisClient")
var repoSetup string
if cache {
repoSetup = fmt.Sprintf("\tbase%sRepo := repository.NewPostgres%sRepository(c.db)\n", featureName, featureName)
if wireCache {
repoSetup = fmt.Sprintf("\tbase%sRepo := repository.New%s%sRepository(c.db)\n", featureName, repoPrefix, featureName)
repoSetup += fmt.Sprintf("\tc.%sRepo = repository.NewCached%sRepository(base%sRepo, c.redisClient, 5*time.Minute)\n", featureLower, featureName, featureName)
} else {
repoSetup = fmt.Sprintf("\tc.%sRepo = repository.NewPostgres%sRepository(c.db)\n", featureLower, featureName)
repoSetup = fmt.Sprintf("\tc.%sRepo = repository.New%s%sRepository(c.db)\n", featureLower, repoPrefix, featureName)
}
setupRepoEnd := "}\n\nfunc (c *Container) setupUseCases() {"
content = strings.Replace(content, setupRepoEnd, repoSetup+setupRepoEnd, 1)
Expand Down Expand Up @@ -721,7 +729,16 @@ func ensureMainGoImport(content, importLine string) string {
// ensureContainerScaffold injects (once) the DI container instantiation and an
// /api/v1 subrouter together with the route marker into main.go.
func ensureContainerScaffold(content string) string {
if strings.Contains(content, "container := di.NewContainer(db)") {
// MongoDB projects expose a *mongo.Client named mongoClient (and no `db`
// variable); the container's NewContainer takes a *mongo.Database, so the
// handle must be derived from the client. Every other backend exposes a
// *gorm.DB named db that the container accepts directly.
dbArg := "db"
if strings.Contains(content, "mongoClient") && !strings.Contains(content, "\tdb ") {
dbArg = "mongoClient.Database(cfg.Database.Name)"
}

if strings.Contains(content, fmt.Sprintf("container := di.NewContainer(%s)", dbArg)) {
return content
}

Expand All @@ -732,7 +749,7 @@ func ensureContainerScaffold(content string) string {

scaffold := anchor + "\n" +
"\t// Dependency injection container\n" +
"\tcontainer := di.NewContainer(db)\n" +
fmt.Sprintf("\tcontainer := di.NewContainer(%s)\n", dbArg) +
"\t_ = container\n\n" +
"\t// API v1 routes\n" +
"\tapiRouter := router.PathPrefix(\"/api/v1\").Subrouter()\n" +
Expand Down
Loading
Loading