diff --git a/README.md b/README.md index 91fae0e..6b87c17 100644 --- a/README.md +++ b/README.md @@ -66,14 +66,14 @@ seedstorm gaps \ ## Features -- **Schema self-discovery** — introspects tables, columns, PKs, FKs, enum values, UNIQUE and CHECK constraints; no manual editing required +- **Schema self-discovery** — introspects tables, columns, PKs, FKs, enum values, UNIQUE and CHECK constraints, generated columns, comments, defaults, and indexes; no manual editing required - **FK-aware seeding** — topological sort guarantees parent tables are seeded before children; handles nullable and non-nullable self-referential FKs with bounded depth, near-cycles, junction tables, and deep multi-level chains - **Constraint-aware faker mapping** — UNIQUE → `uuid`, CHECK IN → `randomstring(a,b,c)`, CHECK range → `number(min,max)`; seed data always satisfies your constraints - **Semantic faker** — maps column names (`email`, `first_name`, `price`, `city`…) to realistic `gofakeit` generators automatically - **Enum coverage** — every enum value appears at least `--rows` times, independently per column - **AI enrichment** — Gemini rewrites faker hints for domain-meaningful data; supply `--prompt` for richer context - **Gap analysis** — `gaps` shows which tables are empty with row counts and FK context; `--fill` seeds only the empty ones -- **Schema clone for test DBs** — copy schema-only structure from one connected Postgres/MySQL database into another matching local target, then seed it with safe fake data +- **Schema clone for test DBs** — copy schema-only structure from one connected Postgres/MySQL database into another matching local target, preserving compatible table metadata before seeding it with safe fake data - **Interactive TUI** — wizard for table selection, global config, self-reference depth, per-table row volumes, and review before seeding - **Web UI** — `seedstorm serve` exposes an interactive graph workspace with click-to-select tables, self-reference depth, per-table row overrides, truncate-only runs (`Rows = 0` + `truncate`), live SSE job logs, schema clone between connected DBs, multi-DB session switcher, and connection presets in `localStorage` - **Dry-run** — preview the seed plan and INSERT SQL without touching the database diff --git a/docs/commands.md b/docs/commands.md index 05e98c3..58b956e 100644 --- a/docs/commands.md +++ b/docs/commands.md @@ -274,7 +274,7 @@ seedstorm export --data data.yaml --format csv --out data.csv ## `clone-schema` -Copies schema-only table structure from a source database into a target database of the same engine. This is designed for local/test database setup before running `seed`; it recreates the metadata seedstorm understands: tables, columns, nullability, PKs, FKs, single-column UNIQUE constraints, enum values, and simple CHECK constraints. +Copies schema-only table structure from a source database into a target database of the same engine. This is designed for local/test database setup before running `seed`; it recreates compatible table metadata seedstorm understands: tables, columns, exact introspected column DDL types, nullability, defaults, stored generated columns, PKs, FKs, single-column UNIQUE constraints, multi-column indexes, enum values, simple CHECK constraints, and table/column comments. ```bash seedstorm clone-schema \ @@ -310,6 +310,8 @@ seedstorm clone-schema \ | `--dry-run` / `-n` | false | Print generated DDL, do not execute | | `--interactive` / `-i` | false | Confirm the clone in the terminal UI | +Boundaries: `clone-schema` is same-engine only. It does not attempt cross-engine translation, and it does not clone views, triggers, functions/procedures, partial/expression indexes, grants, ownership, or non-public/non-current schemas. + --- ## `serve` diff --git a/docs/schema.md b/docs/schema.md index b41e8e9..d247dd0 100644 --- a/docs/schema.md +++ b/docs/schema.md @@ -46,6 +46,7 @@ tables: | `faker` | Faker hint (see table below). Auto-assigned by `introspect`; overridden by `ai-enrich`. | | `nullable` | `true` if the column allows NULL | | `unique` | `true` if the column has a UNIQUE constraint (auto-sets faker to `uuid`) | +| `generated` | `true` for database-generated columns; seedstorm keeps them out of generated INSERT rows | ## Faker Hints Reference diff --git a/integration/integration_test.go b/integration/integration_test.go index 620593b..436fb71 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -187,13 +187,21 @@ func cloneSmokeSchema(t *testing.T, driver string, conn *sql.DB) { CREATE TABLE clone_users ( id integer PRIMARY KEY, email varchar(255) NOT NULL UNIQUE, - status varchar(20) NOT NULL CHECK (status IN ('active', 'blocked')) + status varchar(20) NOT NULL DEFAULT 'active' CHECK (status IN ('active', 'blocked')), + full_label varchar(300) GENERATED ALWAYS AS (email || ':' || status) STORED ); CREATE TABLE clone_orders ( id integer PRIMARY KEY, user_id integer NOT NULL REFERENCES clone_users(id), - total integer NOT NULL CHECK (total BETWEEN 1 AND 500) + subtotal numeric(10,2) NOT NULL DEFAULT 10.00, + tax numeric(10,2) NOT NULL DEFAULT 0.00, + total numeric(10,2) GENERATED ALWAYS AS (subtotal + tax) STORED, + quantity integer NOT NULL CHECK (quantity BETWEEN 1 AND 500) ); + CREATE INDEX idx_clone_orders_user_total ON clone_orders(user_id, total); + CREATE UNIQUE INDEX uq_clone_users_status_email ON clone_users(status, email); + COMMENT ON TABLE clone_users IS 'clone source users'; + COMMENT ON COLUMN clone_users.status IS 'workflow state'; `) return } @@ -205,14 +213,22 @@ func cloneSmokeSchema(t *testing.T, driver string, conn *sql.DB) { CREATE TABLE clone_users ( id integer PRIMARY KEY, email varchar(255) NOT NULL UNIQUE, - status varchar(20) NOT NULL CHECK (status IN ('active', 'blocked')) - ); + status varchar(20) NOT NULL DEFAULT 'active' CHECK (status IN ('active', 'blocked')), + full_label varchar(300) GENERATED ALWAYS AS (concat(email, ':', status)) STORED, + UNIQUE KEY uq_clone_users_status_email (status, email) + ) COMMENT='clone source users'; + ALTER TABLE clone_users MODIFY status varchar(20) NOT NULL DEFAULT 'active' COMMENT 'workflow state'; + CREATE INDEX idx_clone_users_status ON clone_users(status); CREATE TABLE clone_orders ( id integer PRIMARY KEY, user_id integer NOT NULL, - total integer NOT NULL CHECK (total BETWEEN 1 AND 500), + subtotal decimal(10,2) NOT NULL DEFAULT 10.00, + tax decimal(10,2) NOT NULL DEFAULT 0.00, + total decimal(10,2) GENERATED ALWAYS AS (subtotal + tax) STORED, + quantity integer NOT NULL CHECK (quantity BETWEEN 1 AND 500), FOREIGN KEY (user_id) REFERENCES clone_users(id) ); + CREATE INDEX idx_clone_orders_user_total ON clone_orders(user_id, total); `) } @@ -232,10 +248,11 @@ func assertCloneSchemaCanSeed(t *testing.T, driver string, conn *sql.DB, tables st := schema.Table{Columns: make(map[string]schema.Column, len(tbl.Columns))} for _, col := range tbl.Columns { sc := schema.Column{ - Type: col.Type, - PK: col.IsPK, - Nullable: col.IsNullable, - Faker: faker.MapColumnToFaker(driver, col), + Type: col.Type, + PK: col.IsPK, + Nullable: col.IsNullable, + Generated: col.Generated != "", + Faker: faker.MapColumnToFaker(driver, col), } if col.Name == "email" { sc.Faker = "email" @@ -272,6 +289,80 @@ func assertCloneSchemaCanSeed(t *testing.T, driver string, conn *sql.DB, tables } } +func assertCloneMetadata(t *testing.T, tables []db.Table) { + t.Helper() + byName := make(map[string]db.Table, len(tables)) + for _, tbl := range tables { + byName[tbl.Name] = tbl + } + users, ok := byName["clone_users"] + if !ok { + t.Fatal("clone_users missing from cloned metadata") + } + if users.Comment != "clone source users" { + t.Fatalf("clone_users comment = %q", users.Comment) + } + var status, label db.Column + for _, col := range users.Columns { + switch col.Name { + case "status": + status = col + case "full_label": + label = col + } + } + if status.Default == "" { + t.Fatal("clone_users.status default was not preserved") + } + if status.Comment != "workflow state" { + t.Fatalf("clone_users.status comment = %q", status.Comment) + } + if label.Generated == "" { + t.Fatal("clone_users.full_label generated expression was not preserved") + } + if !hasIndex(users.Indexes, "uq_clone_users_status_email", true, []string{"status", "email"}) { + t.Fatalf("multi-column unique index not preserved: %#v", users.Indexes) + } + orders := byName["clone_orders"] + if !hasIndex(orders.Indexes, "idx_clone_orders_user_total", false, []string{"user_id", "total"}) { + t.Fatalf("multi-column index not preserved: %#v", orders.Indexes) + } + var subtotal, total db.Column + for _, col := range orders.Columns { + switch col.Name { + case "subtotal": + subtotal = col + case "total": + total = col + } + } + if subtotal.Default == "" { + t.Fatal("clone_orders.subtotal default was not preserved") + } + if total.Generated == "" { + t.Fatal("clone_orders.total generated expression was not preserved") + } +} + +func hasIndex(indexes []db.Index, name string, unique bool, columns []string) bool { + for _, idx := range indexes { + if idx.Name != name || idx.Unique != unique || len(idx.Columns) != len(columns) { + continue + } + match := true + for i := range columns { + if idx.Columns[i] != columns[i] { + match = false + break + } + } + if match { + return true + } + } + return false +} + // buildAndSeed runs the full introspect → build schema → generate → seed pipeline. // It prints a summary at the end (not per-row during insert). func buildAndSeed(t *testing.T, label, driver, dsn string, conn *sql.DB) map[string][]map[string]interface{} { @@ -292,10 +383,11 @@ func buildAndSeed(t *testing.T, label, driver, dsn string, conn *sql.DB) map[str st := schema.Table{Columns: make(map[string]schema.Column, len(tbl.Columns))} for _, col := range tbl.Columns { sc := schema.Column{ - Type: col.Type, - PK: col.IsPK, - Nullable: col.IsNullable, - Faker: faker.MapColumnToFaker(driver, col), + Type: col.Type, + PK: col.IsPK, + Nullable: col.IsNullable, + Generated: col.Generated != "", + Faker: faker.MapColumnToFaker(driver, col), } if col.FK != nil { sc.FK = fmt.Sprintf("%s.%s", col.FK.TableName, col.FK.ColumnName) @@ -1508,7 +1600,7 @@ func TestPostgresIntegration(t *testing.T) { for _, tbl := range tables { st := schema.Table{Columns: make(map[string]schema.Column, len(tbl.Columns))} for _, col := range tbl.Columns { - sc := schema.Column{Type: col.Type, PK: col.IsPK, Nullable: col.IsNullable} + sc := schema.Column{Type: col.Type, PK: col.IsPK, Nullable: col.IsNullable, Generated: col.Generated != ""} if col.FK != nil { sc.FK = fmt.Sprintf("%s.%s", col.FK.TableName, col.FK.ColumnName) } @@ -1603,6 +1695,7 @@ func TestPostgresSchemaCloneDDL(t *testing.T) { if len(cloned) != 2 { t.Fatalf("cloned tables = %d, want 2", len(cloned)) } + assertCloneMetadata(t, cloned) assertCloneSchemaCanSeed(t, postgresDriver, conn, cloned) dropCloneSmokeSchema(t, postgresDriver, conn) } @@ -2744,7 +2837,7 @@ func TestMySQLIntegration(t *testing.T) { for _, tbl := range tables { st := schema.Table{Columns: make(map[string]schema.Column, len(tbl.Columns))} for _, col := range tbl.Columns { - sc := schema.Column{Type: col.Type, PK: col.IsPK, Nullable: col.IsNullable} + sc := schema.Column{Type: col.Type, PK: col.IsPK, Nullable: col.IsNullable, Generated: col.Generated != ""} if col.FK != nil { sc.FK = fmt.Sprintf("%s.%s", col.FK.TableName, col.FK.ColumnName) } @@ -2839,6 +2932,7 @@ func TestMySQLSchemaCloneDDL(t *testing.T) { if len(cloned) != 2 { t.Fatalf("cloned tables = %d, want 2", len(cloned)) } + assertCloneMetadata(t, cloned) assertCloneSchemaCanSeed(t, mysqlDriver, conn, cloned) dropCloneSmokeSchema(t, mysqlDriver, conn) } @@ -2882,10 +2976,11 @@ func seedL0(t *testing.T, driver, dsn string, conn *sql.DB) { st := schema.Table{Columns: make(map[string]schema.Column, len(tbl.Columns))} for _, col := range tbl.Columns { sc := schema.Column{ - Type: col.Type, - PK: col.IsPK, - Nullable: col.IsNullable, - Faker: faker.MapColumnToFaker(driver, col), + Type: col.Type, + PK: col.IsPK, + Nullable: col.IsNullable, + Generated: col.Generated != "", + Faker: faker.MapColumnToFaker(driver, col), } if col.FK != nil { sc.FK = fmt.Sprintf("%s.%s", col.FK.TableName, col.FK.ColumnName) @@ -3001,10 +3096,11 @@ func TestPostgresGaps(t *testing.T) { st := schema.Table{Columns: make(map[string]schema.Column, len(tbl.Columns))} for _, col := range tbl.Columns { sc := schema.Column{ - Type: col.Type, - PK: col.IsPK, - Nullable: col.IsNullable, - Faker: faker.MapColumnToFaker(postgresDriver, col), + Type: col.Type, + PK: col.IsPK, + Nullable: col.IsNullable, + Generated: col.Generated != "", + Faker: faker.MapColumnToFaker(postgresDriver, col), } if col.FK != nil { sc.FK = fmt.Sprintf("%s.%s", col.FK.TableName, col.FK.ColumnName) @@ -3185,10 +3281,11 @@ func TestMySQLGaps(t *testing.T) { st := schema.Table{Columns: make(map[string]schema.Column, len(tbl.Columns))} for _, col := range tbl.Columns { sc := schema.Column{ - Type: col.Type, - PK: col.IsPK, - Nullable: col.IsNullable, - Faker: faker.MapColumnToFaker(mysqlDriver, col), + Type: col.Type, + PK: col.IsPK, + Nullable: col.IsNullable, + Generated: col.Generated != "", + Faker: faker.MapColumnToFaker(mysqlDriver, col), } if col.FK != nil { sc.FK = fmt.Sprintf("%s.%s", col.FK.TableName, col.FK.ColumnName) diff --git a/internal/cli/introspect.go b/internal/cli/introspect.go index f2daf28..465a367 100644 --- a/internal/cli/introspect.go +++ b/internal/cli/introspect.go @@ -67,10 +67,11 @@ Outputs a schema.yaml that can be used for seeding or AI enrichment.`, } for _, c := range t.Columns { sc := schema.Column{ - Type: c.Type, - PK: c.IsPK, - Nullable: c.IsNullable, - Faker: faker.MapColumnToFaker(dbType, c), + Type: c.Type, + PK: c.IsPK, + Nullable: c.IsNullable, + Generated: c.Generated != "", + Faker: faker.MapColumnToFaker(dbType, c), } if c.FK != nil { sc.FK = fmt.Sprintf("%s.%s", c.FK.TableName, c.FK.ColumnName) diff --git a/internal/db/clone.go b/internal/db/clone.go index 9e22d45..7593439 100644 --- a/internal/db/clone.go +++ b/internal/db/clone.go @@ -97,6 +97,8 @@ func BuildSchemaDDL(tables []Table, dbType string, dropExisting bool) ([]string, stmts = append(stmts, stmt) } stmts = append(stmts, buildForeignKeyDDL(ordered, dbType)...) + stmts = append(stmts, buildIndexDDL(ordered, dbType)...) + stmts = append(stmts, buildCommentDDL(ordered, dbType)...) return stmts, nil } @@ -152,22 +154,7 @@ func buildCreateTable(table Table, dbType string) (string, error) { if col.Name == "" { return "", fmt.Errorf("table %s has empty column name", table.Name) } - def := fmt.Sprintf("%s %s", QuoteIdent(col.Name, dbType), cloneColumnType(col, dbType)) - if !col.IsNullable || col.IsPK { - def += " NOT NULL" - } - if col.Unique && !col.IsPK { - def += " UNIQUE" - } - if len(col.CheckValues) > 0 { - def += " CHECK (" + QuoteIdent(col.Name, dbType) + " IN (" + quotedLiterals(col.CheckValues) + "))" - } - if dbType == "pgx" && len(col.EnumValues) > 0 { - def += " CHECK (" + QuoteIdent(col.Name, dbType) + " IN (" + quotedLiterals(col.EnumValues) + "))" - } - if col.CheckMin != nil && col.CheckMax != nil { - def += fmt.Sprintf(" CHECK (%s BETWEEN %d AND %d)", QuoteIdent(col.Name, dbType), *col.CheckMin, *col.CheckMax) - } + def := buildColumnDDL(col, dbType) defs = append(defs, def) if col.IsPK { pkCols = append(pkCols, QuoteIdent(col.Name, dbType)) @@ -176,7 +163,51 @@ func buildCreateTable(table Table, dbType string) (string, error) { if len(pkCols) > 0 { defs = append(defs, "PRIMARY KEY ("+strings.Join(pkCols, ", ")+")") } - return fmt.Sprintf("CREATE TABLE %s (\n %s\n)", QuoteIdent(table.Name, dbType), strings.Join(defs, ",\n ")), nil + stmt := fmt.Sprintf("CREATE TABLE %s (\n %s\n)", QuoteIdent(table.Name, dbType), strings.Join(defs, ",\n ")) + if dbType == "mysql" && table.Comment != "" { + stmt += " COMMENT=" + quoteStringLiteral(table.Comment) + } + return stmt, nil +} + +func buildColumnDDL(col Column, dbType string) string { + def := fmt.Sprintf("%s %s", QuoteIdent(col.Name, dbType), cloneColumnType(col, dbType)) + if col.Generated != "" { + def += " GENERATED ALWAYS AS (" + col.Generated + ") STORED" + if dbType == "mysql" && col.Comment != "" { + def += " COMMENT " + quoteStringLiteral(col.Comment) + } + return def + } + if col.AutoIncrement { + if dbType == "pgx" { + def += " GENERATED BY DEFAULT AS IDENTITY" + } else { + def += " AUTO_INCREMENT" + } + } + if !col.IsNullable || col.IsPK { + def += " NOT NULL" + } + if col.Default != "" && !col.AutoIncrement { + def += " DEFAULT " + cloneColumnDefault(col, dbType) + } + if col.Unique && !col.IsPK { + def += " UNIQUE" + } + if len(col.CheckValues) > 0 { + def += " CHECK (" + QuoteIdent(col.Name, dbType) + " IN (" + quotedLiterals(col.CheckValues) + "))" + } + if dbType == "pgx" && len(col.EnumValues) > 0 { + def += " CHECK (" + QuoteIdent(col.Name, dbType) + " IN (" + quotedLiterals(col.EnumValues) + "))" + } + if col.CheckMin != nil && col.CheckMax != nil { + def += fmt.Sprintf(" CHECK (%s BETWEEN %d AND %d)", QuoteIdent(col.Name, dbType), *col.CheckMin, *col.CheckMax) + } + if dbType == "mysql" && col.Comment != "" { + def += " COMMENT " + quoteStringLiteral(col.Comment) + } + return def } func buildForeignKeyDDL(tables []Table, dbType string) []string { @@ -198,17 +229,71 @@ func buildForeignKeyDDL(tables []Table, dbType string) []string { return stmts } -func cloneColumnType(col Column, dbType string) string { - t := strings.ToLower(strings.TrimSpace(col.Type)) - if t == "" { - t = "text" +func buildIndexDDL(tables []Table, dbType string) []string { + var stmts []string + for _, table := range tables { + indexes := append([]Index(nil), table.Indexes...) + sort.Slice(indexes, func(i, j int) bool { return indexes[i].Name < indexes[j].Name }) + for _, idx := range indexes { + if idx.Name == "" || len(idx.Columns) == 0 { + continue + } + cols := make([]string, 0, len(idx.Columns)) + for _, col := range idx.Columns { + cols = append(cols, QuoteIdent(col, dbType)) + } + unique := "" + if idx.Unique { + unique = "UNIQUE " + } + stmts = append(stmts, fmt.Sprintf("CREATE %sINDEX %s ON %s (%s)", + unique, + QuoteIdent(idx.Name, dbType), + QuoteIdent(table.Name, dbType), + strings.Join(cols, ", "))) + } } + return stmts +} + +func buildCommentDDL(tables []Table, dbType string) []string { + if dbType != "pgx" { + return nil + } + var stmts []string + for _, table := range tables { + if table.Comment != "" { + stmts = append(stmts, fmt.Sprintf("COMMENT ON TABLE %s IS %s", QuoteIdent(table.Name, dbType), quoteStringLiteral(table.Comment))) + } + cols := append([]Column(nil), table.Columns...) + sort.Slice(cols, func(i, j int) bool { return cols[i].Name < cols[j].Name }) + for _, col := range cols { + if col.Comment == "" { + continue + } + stmts = append(stmts, fmt.Sprintf("COMMENT ON COLUMN %s.%s IS %s", + QuoteIdent(table.Name, dbType), + QuoteIdent(col.Name, dbType), + quoteStringLiteral(col.Comment))) + } + } + return stmts +} + +func cloneColumnType(col Column, dbType string) string { if dbType == "mysql" && len(col.EnumValues) > 0 { return "ENUM(" + quotedLiterals(col.EnumValues) + ")" } if dbType == "pgx" && len(col.EnumValues) > 0 { return "TEXT" } + if typ := strings.TrimSpace(col.DDLType); typ != "" { + return typ + } + t := strings.ToLower(strings.TrimSpace(col.Type)) + if t == "" { + t = "text" + } switch t { case "character varying": if dbType == "mysql" { @@ -251,6 +336,15 @@ func cloneColumnType(col Column, dbType string) string { return "TEXT" } +func cloneColumnDefault(col Column, dbType string) string { + if dbType == "pgx" && len(col.EnumValues) > 0 { + if idx := strings.Index(col.Default, "::"); idx > 0 { + return col.Default[:idx] + } + } + return col.Default +} + func quotedLiterals(values []string) string { parts := make([]string, len(values)) for i, v := range values { @@ -259,6 +353,10 @@ func quotedLiterals(values []string) string { return strings.Join(parts, ", ") } +func quoteStringLiteral(value string) string { + return "'" + strings.ReplaceAll(value, "'", "''") + "'" +} + func orderTablesByName(tables []Table) []Table { out := append([]Table(nil), tables...) sort.Slice(out, func(i, j int) bool { return out[i].Name < out[j].Name }) diff --git a/internal/db/clone_test.go b/internal/db/clone_test.go index 4a43b3a..b5c02bf 100644 --- a/internal/db/clone_test.go +++ b/internal/db/clone_test.go @@ -126,6 +126,33 @@ func TestBuildSchemaDDL_postgresEnumValuesBecomeCheck(t *testing.T) { } } +func TestBuildSchemaDDL_postgresEnumDDLTypeDoesNotRequireSourceEnumType(t *testing.T) { + tables := []Table{{ + Name: "coupons", + Columns: []Column{ + {Name: "id", DDLType: "integer", Type: "integer", IsPK: true}, + {Name: "discount_type", DDLType: "discount_type", Type: "discount_type", EnumValues: []string{"percentage", "fixed"}, Default: "'percentage'::discount_type"}, + }, + }} + stmts, err := BuildSchemaDDL(tables, "pgx", false) + if err != nil { + t.Fatalf("BuildSchemaDDL: %v", err) + } + ddl := strings.Join(stmts, "\n") + if strings.Contains(ddl, `"discount_type" discount_type`) { + t.Fatalf("clone DDL should not require source enum type to exist:\n%s", ddl) + } + if !strings.Contains(ddl, `CHECK ("discount_type" IN ('percentage', 'fixed'))`) { + t.Fatalf("expected enum to clone as TEXT with CHECK:\n%s", ddl) + } + if strings.Contains(ddl, "::discount_type") { + t.Fatalf("clone DDL should not require source enum type in defaults:\n%s", ddl) + } + if !strings.Contains(ddl, `"discount_type" TEXT NOT NULL DEFAULT 'percentage' CHECK`) { + t.Fatalf("expected enum default cast to be stripped:\n%s", ddl) + } +} + func TestBuildSchemaDDL_mysqlColumnConstraintOrder(t *testing.T) { tables := []Table{{ Name: "users", @@ -149,3 +176,136 @@ func TestBuildSchemaDDL_mysqlColumnConstraintOrder(t *testing.T) { } } } + +func TestBuildSchemaDDL_postgresSerialDefaultBecomesIdentity(t *testing.T) { + tables := []Table{{ + Name: "addresses", + Columns: []Column{ + {Name: "id", DDLType: "integer", Type: "integer", IsPK: true, AutoIncrement: true, Default: "nextval('addresses_id_seq'::regclass)"}, + }, + }} + stmts, err := BuildSchemaDDL(tables, "pgx", false) + if err != nil { + t.Fatalf("BuildSchemaDDL: %v", err) + } + ddl := strings.Join(stmts, "\n") + if strings.Contains(ddl, "addresses_id_seq") { + t.Fatalf("source sequence name should not be copied into clone DDL:\n%s", ddl) + } + if !strings.Contains(ddl, `"id" integer GENERATED BY DEFAULT AS IDENTITY NOT NULL`) { + t.Fatalf("serial PK should become identity column:\n%s", ddl) + } +} + +func TestBuildSchemaDDL_mysqlAutoIncrement(t *testing.T) { + tables := []Table{{ + Name: "addresses", + Columns: []Column{ + {Name: "id", DDLType: "int", Type: "integer", IsPK: true, AutoIncrement: true}, + }, + }} + stmts, err := BuildSchemaDDL(tables, "mysql", false) + if err != nil { + t.Fatalf("BuildSchemaDDL: %v", err) + } + ddl := strings.Join(stmts, "\n") + if !strings.Contains(ddl, "`id` int AUTO_INCREMENT NOT NULL") { + t.Fatalf("auto_increment should be preserved:\n%s", ddl) + } +} + +func TestIsPostgresSerialDefault(t *testing.T) { + if !isPostgresSerialDefault("nextval('addresses_id_seq'::regclass)") { + t.Fatal("expected nextval default to be detected as serial") + } + if isPostgresSerialDefault("'pending'::text") { + t.Fatal("ordinary string default should not be serial") + } +} + +func TestPostgresIndexQueryAvoidsArrayPositionOnInt2Vector(t *testing.T) { + query := postgresIndexQuery() + if strings.Contains(query, "array_position(ix.indkey") { + t.Fatalf("Postgres index query must avoid array_position(int2vector, ...) for PG13 compatibility:\n%s", query) + } + if !strings.Contains(query, "FROM unnest(ix.indkey)") { + t.Fatalf("Postgres index query should filter expression indexes through unnest(ix.indkey):\n%s", query) + } +} + +func TestBuildSchemaDDL_preservesDefaultsGeneratedIndexesAndComments(t *testing.T) { + tables := []Table{{ + Name: "orders", + Comment: "order table", + Columns: []Column{ + {Name: "id", DDLType: "integer", Type: "integer", IsPK: true}, + {Name: "status", DDLType: "varchar(20)", Type: "character varying", IsNullable: false, Default: "'new'::character varying", Comment: "workflow state"}, + {Name: "subtotal", DDLType: "numeric(10,2)", Type: "numeric", IsNullable: false, Default: "0"}, + {Name: "tax", DDLType: "numeric(10,2)", Type: "numeric", IsNullable: false, Default: "0"}, + {Name: "total", DDLType: "numeric(10,2)", Type: "numeric", Generated: "(subtotal + tax)"}, + }, + Indexes: []Index{ + {Name: "idx_orders_status_subtotal", Columns: []string{"status", "subtotal"}}, + {Name: "uq_orders_status_total", Columns: []string{"status", "total"}, Unique: true}, + }, + }} + stmts, err := BuildSchemaDDL(tables, "pgx", false) + if err != nil { + t.Fatalf("BuildSchemaDDL: %v", err) + } + ddl := strings.Join(stmts, "\n") + for _, want := range []string{ + `"status" varchar(20) NOT NULL DEFAULT 'new'::character varying`, + `"subtotal" numeric(10,2) NOT NULL DEFAULT 0`, + `"total" numeric(10,2) GENERATED ALWAYS AS ((subtotal + tax)) STORED`, + `CREATE INDEX "idx_orders_status_subtotal" ON "orders" ("status", "subtotal")`, + `CREATE UNIQUE INDEX "uq_orders_status_total" ON "orders" ("status", "total")`, + `COMMENT ON TABLE "orders" IS 'order table'`, + `COMMENT ON COLUMN "orders"."status" IS 'workflow state'`, + } { + if !strings.Contains(ddl, want) { + t.Fatalf("missing %q in:\n%s", want, ddl) + } + } +} + +func TestBuildSchemaDDL_mysqlPreservesDefaultsGeneratedIndexesAndComments(t *testing.T) { + tables := []Table{{ + Name: "orders", + Comment: "order table", + Columns: []Column{ + {Name: "id", DDLType: "int", Type: "integer", IsPK: true}, + {Name: "status", DDLType: "varchar(20)", Type: "varchar", IsNullable: false, Default: "'new'", Comment: "workflow state"}, + {Name: "subtotal", DDLType: "decimal(10,2)", Type: "decimal", IsNullable: false, Default: "0.00"}, + {Name: "tax", DDLType: "decimal(10,2)", Type: "decimal", IsNullable: false, Default: "0.00"}, + {Name: "total", DDLType: "decimal(10,2)", Type: "decimal", Generated: "`subtotal` + `tax`"}, + }, + Indexes: []Index{ + {Name: "idx_orders_status_subtotal", Columns: []string{"status", "subtotal"}}, + }, + }} + stmts, err := BuildSchemaDDL(tables, "mysql", false) + if err != nil { + t.Fatalf("BuildSchemaDDL: %v", err) + } + ddl := strings.Join(stmts, "\n") + for _, want := range []string{ + "`status` varchar(20) NOT NULL DEFAULT 'new' COMMENT 'workflow state'", + "`subtotal` decimal(10,2) NOT NULL DEFAULT 0.00", + "`total` decimal(10,2) GENERATED ALWAYS AS (`subtotal` + `tax`) STORED", + "CREATE INDEX `idx_orders_status_subtotal` ON `orders` (`status`, `subtotal`)", + "COMMENT='order table'", + } { + if !strings.Contains(ddl, want) { + t.Fatalf("missing %q in:\n%s", want, ddl) + } + } +} + +func TestMySQLGeneratedExpressionNormalizesEscapedStringLiterals(t *testing.T) { + got := mysqlGeneratedExpression("concat(`email`,_utf8mb4\\':\\',`status`)") + want := "concat(`email`,_utf8mb4':',`status`)" + if got != want { + t.Fatalf("mysqlGeneratedExpression() = %q, want %q", got, want) + } +} diff --git a/internal/db/mysql.go b/internal/db/mysql.go index 1372ded..9f57fb4 100644 --- a/internal/db/mysql.go +++ b/internal/db/mysql.go @@ -63,13 +63,23 @@ func introspectMySQL(db *sql.DB) ([]Table, error) { } } + indexMap, err := mysqlIndexMap(db, dbName) + if err != nil { + return nil, err + } + + tableComments, err := mysqlTableCommentMap(db, dbName) + if err != nil { + return nil, err + } + var tables []Table for _, tableName := range tableNames { cols, err := mysqlColumns(db, dbName, tableName, fkMap, checkMap, rangeMap) if err != nil { return nil, fmt.Errorf("failed to introspect table %s: %w", tableName, err) } - tables = append(tables, Table{Name: tableName, Columns: cols}) + tables = append(tables, Table{Name: tableName, Columns: cols, Indexes: indexMap[tableName], Comment: tableComments[tableName]}) } return tables, nil @@ -133,7 +143,11 @@ func mysqlColumns(db *sql.DB, dbName, tableName string, fkMap map[string]map[str DATA_TYPE, COLUMN_TYPE, IS_NULLABLE, - COLUMN_KEY + COLUMN_KEY, + COLUMN_DEFAULT, + EXTRA, + GENERATION_EXPRESSION, + COLUMN_COMMENT FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? @@ -145,17 +159,30 @@ func mysqlColumns(db *sql.DB, dbName, tableName string, fkMap map[string]map[str var columns []Column for rows.Next() { - var name, dataType, columnType, isNullable, columnKey string - if err := rows.Scan(&name, &dataType, &columnType, &isNullable, &columnKey); err != nil { + var name, dataType, columnType, isNullable, columnKey, extra, generationExpr, comment string + var defaultValue sql.NullString + if err := rows.Scan(&name, &dataType, &columnType, &isNullable, &columnKey, &defaultValue, &extra, &generationExpr, &comment); err != nil { return nil, err } col := Column{ Name: name, Type: strings.ToLower(dataType), + DDLType: columnType, IsNullable: isNullable == "YES", IsPK: columnKey == "PRI", Unique: columnKey == "UNI", + Comment: comment, + } + if defaultValue.Valid && !strings.Contains(strings.ToLower(extra), "generated") { + col.Default = mysqlDefaultLiteral(defaultValue.String, dataType) + } + if strings.Contains(strings.ToLower(extra), "auto_increment") { + col.AutoIncrement = true + col.Default = "" + } + if strings.Contains(strings.ToLower(extra), "generated") { + col.Generated = mysqlGeneratedExpression(generationExpr) } // Parse enum values from COLUMN_TYPE e.g. enum('a','b','c') @@ -323,3 +350,75 @@ func parseEnumValues(columnType string) []string { } return values } + +func mysqlIndexMap(db *sql.DB, dbName string) (map[string][]Index, error) { + rows, err := db.Query(` + SELECT + TABLE_NAME, + INDEX_NAME, + NON_UNIQUE, + GROUP_CONCAT(COLUMN_NAME ORDER BY SEQ_IN_INDEX SEPARATOR ',') + FROM information_schema.STATISTICS + WHERE TABLE_SCHEMA = ? + AND INDEX_NAME <> 'PRIMARY' + GROUP BY TABLE_NAME, INDEX_NAME, NON_UNIQUE`, dbName) + if err != nil { + return nil, fmt.Errorf("failed to query indexes: %w", err) + } + defer rows.Close() + + m := make(map[string][]Index) + for rows.Next() { + var table, name, columns string + var nonUnique int + if err := rows.Scan(&table, &name, &nonUnique, &columns); err != nil { + return nil, err + } + cols := strings.Split(columns, ",") + if len(cols) == 1 && nonUnique == 0 { + continue + } + m[table] = append(m[table], Index{Name: name, Columns: cols, Unique: nonUnique == 0}) + } + return m, nil +} + +func mysqlTableCommentMap(db *sql.DB, dbName string) (map[string]string, error) { + rows, err := db.Query(` + SELECT TABLE_NAME, TABLE_COMMENT + FROM information_schema.TABLES + WHERE TABLE_SCHEMA = ? + AND TABLE_TYPE = 'BASE TABLE' + AND TABLE_COMMENT <> ''`, dbName) + if err != nil { + return nil, fmt.Errorf("failed to query table comments: %w", err) + } + defer rows.Close() + + m := make(map[string]string) + for rows.Next() { + var table, comment string + if err := rows.Scan(&table, &comment); err != nil { + return nil, err + } + m[table] = comment + } + return m, nil +} + +func mysqlDefaultLiteral(value, dataType string) string { + switch strings.ToLower(dataType) { + case "char", "varchar", "text", "mediumtext", "longtext", "enum", "set", "date", "time", "datetime", "timestamp": + upper := strings.ToUpper(value) + if upper == "CURRENT_TIMESTAMP" || strings.HasSuffix(upper, "()") { + return value + } + return quoteStringLiteral(value) + default: + return value + } +} + +func mysqlGeneratedExpression(expr string) string { + return strings.ReplaceAll(expr, `\'`, `'`) +} diff --git a/internal/db/postgres.go b/internal/db/postgres.go index 32a8271..f304075 100644 --- a/internal/db/postgres.go +++ b/internal/db/postgres.go @@ -54,13 +54,23 @@ func introspectPostgres(db *sql.DB) ([]Table, error) { return nil, err } + indexMap, err := postgresIndexMap(db, uniqueMap) + if err != nil { + return nil, err + } + + tableComments, columnComments, err := postgresCommentMaps(db) + if err != nil { + return nil, err + } + var tables []Table for _, tableName := range tableNames { - cols, err := postgresColumns(db, tableName, fkMap, pkMap, uniqueMap, checkMap, rangeMap) + cols, err := postgresColumns(db, tableName, fkMap, pkMap, uniqueMap, checkMap, rangeMap, columnComments) if err != nil { return nil, fmt.Errorf("failed to introspect table %s: %w", tableName, err) } - tables = append(tables, Table{Name: tableName, Columns: cols}) + tables = append(tables, Table{Name: tableName, Columns: cols, Indexes: indexMap[tableName], Comment: tableComments[tableName]}) } return tables, nil @@ -131,17 +141,24 @@ func postgresPKMap(db *sql.DB) (map[string]map[string]bool, error) { type rangeConstraint struct{ Min, Max int64 } -func postgresColumns(db *sql.DB, tableName string, fkMap map[string]map[string]*ForeignKey, pkMap map[string]map[string]bool, uniqueMap map[string]map[string]bool, checkMap map[string]map[string][]string, rangeMap map[string]map[string]rangeConstraint) ([]Column, error) { +func postgresColumns(db *sql.DB, tableName string, fkMap map[string]map[string]*ForeignKey, pkMap map[string]map[string]bool, uniqueMap map[string]map[string]bool, checkMap map[string]map[string][]string, rangeMap map[string]map[string]rangeConstraint, columnComments map[string]map[string]string) ([]Column, error) { rows, err := db.Query(` SELECT - column_name, - data_type, - udt_name, - is_nullable - FROM information_schema.columns - WHERE table_schema = 'public' - AND table_name = $1 - ORDER BY ordinal_position`, tableName) + c.column_name, + c.data_type, + c.udt_name, + c.is_nullable, + c.column_default, + c.is_generated, + c.generation_expression, + format_type(a.atttypid, a.atttypmod) + FROM information_schema.columns c + JOIN pg_class t ON t.relname = c.table_name + JOIN pg_namespace n ON n.oid = t.relnamespace AND n.nspname = c.table_schema + JOIN pg_attribute a ON a.attrelid = t.oid AND a.attname = c.column_name + WHERE c.table_schema = 'public' + AND c.table_name = $1 + ORDER BY c.ordinal_position`, tableName) if err != nil { return nil, err } @@ -149,8 +166,10 @@ func postgresColumns(db *sql.DB, tableName string, fkMap map[string]map[string]* var columns []Column for rows.Next() { - var name, dataType, udtName, isNullable string - if err := rows.Scan(&name, &dataType, &udtName, &isNullable); err != nil { + var name, dataType, udtName, isNullable, isGenerated, ddlType string + var generationExpr sql.NullString + var defaultValue sql.NullString + if err := rows.Scan(&name, &dataType, &udtName, &isNullable, &defaultValue, &isGenerated, &generationExpr, &ddlType); err != nil { return nil, err } @@ -163,10 +182,22 @@ func postgresColumns(db *sql.DB, tableName string, fkMap map[string]map[string]* col := Column{ Name: name, Type: colType, + DDLType: ddlType, IsNullable: isNullable == "YES", IsPK: pkMap[tableName] != nil && pkMap[tableName][name], Unique: uniqueMap[tableName] != nil && uniqueMap[tableName][name], } + if defaultValue.Valid && isPostgresSerialDefault(defaultValue.String) { + col.AutoIncrement = true + } else if defaultValue.Valid && isGenerated != "ALWAYS" { + col.Default = defaultValue.String + } + if isGenerated == "ALWAYS" { + col.Generated = generationExpr.String + } + if columnComments[tableName] != nil { + col.Comment = columnComments[tableName][name] + } // Resolve enum values for user-defined enum types if dataType == "USER-DEFINED" { @@ -359,3 +390,106 @@ func postgresEnumValues(db *sql.DB, typeName string) ([]string, error) { } return values, nil } + +func isPostgresSerialDefault(value string) bool { + return strings.HasPrefix(value, "nextval(") +} + +func postgresIndexMap(db *sql.DB, uniqueMap map[string]map[string]bool) (map[string][]Index, error) { + rows, err := db.Query(postgresIndexQuery()) + if err != nil { + return nil, fmt.Errorf("failed to query indexes: %w", err) + } + defer rows.Close() + + m := make(map[string][]Index) + for rows.Next() { + var table, name string + var unique bool + var columns string + if err := rows.Scan(&table, &name, &unique, &columns); err != nil { + return nil, err + } + cols := strings.Split(strings.Trim(columns, "{}"), ",") + if len(cols) == 1 && unique && uniqueMap[table] != nil && uniqueMap[table][cols[0]] { + continue + } + m[table] = append(m[table], Index{Name: name, Columns: cols, Unique: unique}) + } + return m, nil +} + +func postgresIndexQuery() string { + return ` + SELECT + t.relname AS table_name, + i.relname AS index_name, + ix.indisunique, + string_agg(a.attname, ',' ORDER BY ord.ordinality) AS columns + FROM pg_index ix + JOIN pg_class t ON t.oid = ix.indrelid + JOIN pg_namespace n ON n.oid = t.relnamespace + JOIN pg_class i ON i.oid = ix.indexrelid + JOIN unnest(ix.indkey) WITH ORDINALITY AS ord(attnum, ordinality) ON true + JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = ord.attnum + WHERE n.nspname = 'public' + AND NOT ix.indisprimary + AND ix.indpred IS NULL + AND NOT EXISTS ( + SELECT 1 + FROM unnest(ix.indkey) AS key(attnum) + WHERE key.attnum = 0 + ) + GROUP BY t.relname, i.relname, ix.indisunique` +} + +func postgresCommentMaps(db *sql.DB) (map[string]string, map[string]map[string]string, error) { + tableRows, err := db.Query(` + SELECT c.relname, obj_description(c.oid, 'pg_class') + FROM pg_class c + JOIN pg_namespace n ON n.oid = c.relnamespace + WHERE n.nspname = 'public' + AND c.relkind = 'r' + AND obj_description(c.oid, 'pg_class') IS NOT NULL`) + if err != nil { + return nil, nil, fmt.Errorf("failed to query table comments: %w", err) + } + defer tableRows.Close() + + tableComments := make(map[string]string) + for tableRows.Next() { + var table, comment string + if err := tableRows.Scan(&table, &comment); err != nil { + return nil, nil, err + } + tableComments[table] = comment + } + + columnRows, err := db.Query(` + SELECT c.relname, a.attname, col_description(c.oid, a.attnum) + FROM pg_class c + JOIN pg_namespace n ON n.oid = c.relnamespace + JOIN pg_attribute a ON a.attrelid = c.oid + WHERE n.nspname = 'public' + AND c.relkind = 'r' + AND a.attnum > 0 + AND NOT a.attisdropped + AND col_description(c.oid, a.attnum) IS NOT NULL`) + if err != nil { + return nil, nil, fmt.Errorf("failed to query column comments: %w", err) + } + defer columnRows.Close() + + columnComments := make(map[string]map[string]string) + for columnRows.Next() { + var table, column, comment string + if err := columnRows.Scan(&table, &column, &comment); err != nil { + return nil, nil, err + } + if columnComments[table] == nil { + columnComments[table] = make(map[string]string) + } + columnComments[table][column] = comment + } + return tableComments, columnComments, nil +} diff --git a/internal/db/types.go b/internal/db/types.go index b54cbb3..c22dd3a 100644 --- a/internal/db/types.go +++ b/internal/db/types.go @@ -4,20 +4,27 @@ package db type Table struct { Name string Columns []Column + Indexes []Index + Comment string } // Column represents a column in a database table. type Column struct { - Name string - Type string - IsNullable bool - IsPK bool - FK *ForeignKey - EnumValues []string - Unique bool // column has a single-column UNIQUE constraint - CheckValues []string // values extracted from a CHECK (col IN (...)) constraint - CheckMin *int64 // lower bound from a CHECK (col >= N) or CHECK (col BETWEEN N AND M) constraint - CheckMax *int64 // upper bound from a CHECK (col <= N) constraint + Name string + Type string + DDLType string + IsNullable bool + IsPK bool + FK *ForeignKey + EnumValues []string + Unique bool // column has a single-column UNIQUE constraint + CheckValues []string // values extracted from a CHECK (col IN (...)) constraint + CheckMin *int64 // lower bound from a CHECK (col >= N) or CHECK (col BETWEEN N AND M) constraint + CheckMax *int64 // upper bound from a CHECK (col <= N) constraint + Default string + Generated string + AutoIncrement bool + Comment string } // ForeignKey represents a foreign key reference. @@ -25,3 +32,11 @@ type ForeignKey struct { TableName string ColumnName string } + +// Index represents a non-primary index that should be recreated after tables +// and foreign keys exist. +type Index struct { + Name string + Columns []string + Unique bool +} diff --git a/internal/faker/faker.go b/internal/faker/faker.go index 286c981..eadf793 100644 --- a/internal/faker/faker.go +++ b/internal/faker/faker.go @@ -338,6 +338,9 @@ func generateRow(table schema.Table, tableName string, generatedPKs map[string][ for _, colName := range colNames { col := table.Columns[colName] + if col.Generated { + continue + } val, err := generateValue(col, colName, tableName, generatedPKs, enumVal, enumCol) if err != nil { return nil, fmt.Errorf("column %s: %w", colName, err) diff --git a/internal/faker/faker_test.go b/internal/faker/faker_test.go index 21a3040..9b61e68 100644 --- a/internal/faker/faker_test.go +++ b/internal/faker/faker_test.go @@ -300,6 +300,22 @@ func TestGenerate_noEnumColumns_rowCountUnchanged(t *testing.T) { } } +func TestGenerate_skipsGeneratedColumns(t *testing.T) { + s := &schema.Schema{Tables: map[string]schema.Table{ + "orders": {Columns: map[string]schema.Column{ + "id": {Type: "integer", PK: true}, + "total": {Type: "integer", Generated: true, Faker: "number(1,10)"}, + }}, + }} + rows, err := Generate(s, []string{"orders"}, 1, 0, nil, "pgx") + if err != nil { + t.Fatalf("Generate: %v", err) + } + if _, ok := rows["orders"][0]["total"]; ok { + t.Fatalf("generated column should not be present in insert rows: %#v", rows["orders"][0]) + } +} + // ── generatePK ─────────────────────────────────────────────────────────────── func TestGeneratePK_integerType(t *testing.T) { diff --git a/internal/schema/schema.go b/internal/schema/schema.go index 90ba40f..7d8128e 100644 --- a/internal/schema/schema.go +++ b/internal/schema/schema.go @@ -19,11 +19,12 @@ type Table struct { // Column holds metadata and faker mapping for a single column. type Column struct { - Type string `yaml:"type"` - Faker string `yaml:"faker,omitempty"` - FK string `yaml:"fk,omitempty"` - PK bool `yaml:"pk,omitempty"` - Nullable bool `yaml:"nullable,omitempty"` + Type string `yaml:"type"` + Faker string `yaml:"faker,omitempty"` + FK string `yaml:"fk,omitempty"` + PK bool `yaml:"pk,omitempty"` + Nullable bool `yaml:"nullable,omitempty"` + Generated bool `yaml:"generated,omitempty"` } // Load reads a schema YAML file from disk. diff --git a/internal/web/session.go b/internal/web/session.go index 3d26c80..02c2f8c 100644 --- a/internal/web/session.go +++ b/internal/web/session.go @@ -227,10 +227,11 @@ func (s *Session) Schema(force bool) (*schema.Schema, error) { st := schema.Table{Columns: make(map[string]schema.Column, len(t.Columns))} for _, c := range t.Columns { sc := schema.Column{ - Type: c.Type, - PK: c.IsPK, - Nullable: c.IsNullable, - Faker: faker.MapColumnToFaker(s.DBType, c), + Type: c.Type, + PK: c.IsPK, + Nullable: c.IsNullable, + Generated: c.Generated != "", + Faker: faker.MapColumnToFaker(s.DBType, c), } if c.FK != nil { sc.FK = fmt.Sprintf("%s.%s", c.FK.TableName, c.FK.ColumnName)