From e5e3a2271683ebb25a109990634a26dcf602b5e5 Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 11:53:00 -0300 Subject: [PATCH 01/26] removed TableName and ColumnName structs --- schema/column.go | 60 ++++++++++++++---------------------------------- schema/schema.go | 12 +++++----- schema/table.go | 34 ++++----------------------- 3 files changed, 28 insertions(+), 78 deletions(-) diff --git a/schema/column.go b/schema/column.go index 030fdee..40f09c3 100644 --- a/schema/column.go +++ b/schema/column.go @@ -1,27 +1,5 @@ package schema -import ( - "fmt" - "regexp" -) - -var regexValidColumnName = regexp.MustCompile(`^[A-Za-z][A-Za-z0-9_]*$`) - -type ColumnName struct { - columnName string -} - -func NewColumnName(pName string) (*ColumnName, error) { - if !regexValidColumnName.MatchString(pName) { - return nil, fmt.Errorf("invalid column name: %q", pName) - } - return &ColumnName{pName}, nil -} - -func (tn ColumnName) String() string { - return tn.columnName -} - type ColumnType int const ( @@ -44,30 +22,26 @@ type NewColumnParams struct { } type Column struct { - name ColumnName - columnType ColumnType - size uint - pk bool - notNull bool - autoInc bool + name string + columnType ColumnType + size uint + primaryKey bool + notNull bool + autoIncrement bool } -func NewColumn(pParams NewColumnParams) (*Column, error) { - lColumnName, lError := NewColumnName(pParams.Name) - if lError != nil { - return nil, lError +func NewColumn(pParams NewColumnParams) Column { + return Column{ + name: pParams.Name, + columnType: pParams.Type, + size: pParams.Size, + primaryKey: pParams.PrimaryKey, + notNull: pParams.NotNull, + autoIncrement: pParams.AutoIncrement, } - return &Column{ - name: *lColumnName, - columnType: pParams.Type, - size: pParams.Size, - pk: pParams.PrimaryKey, - notNull: pParams.NotNull, - autoInc: pParams.AutoIncrement, - }, nil } -func (c Column) Name() ColumnName { +func (c Column) Name() string { return c.name } @@ -80,7 +54,7 @@ func (c Column) Size() uint { } func (c Column) PrimaryKey() bool { - return c.pk + return c.primaryKey } func (c Column) NotNull() bool { @@ -88,5 +62,5 @@ func (c Column) NotNull() bool { } func (c Column) AutoIncrement() bool { - return c.autoInc + return c.autoIncrement } diff --git a/schema/schema.go b/schema/schema.go index 2d95994..4e9b91a 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -20,13 +20,13 @@ func NewDBSchema(pDB *sql.DB, pTableNamesQuery string) (*DBSchema, error) { return &DBSchema{db: pDB, tableNamesQuery: pTableNamesQuery}, nil } -func (s DBSchema) HasTable(pTableName TableName) bool { +func (s DBSchema) HasTable(pTableName string) bool { lTables, lError := s.fetchTableNames() if lError != nil { return false } return slices.ContainsFunc(lTables, func(t string) bool { - return strings.EqualFold(t, pTableName.String()) + return strings.EqualFold(t, pTableName) }) } @@ -52,18 +52,18 @@ func (s DBSchema) fetchTableNames() ([]string, error) { return lTableNames, lError } -func (s DBSchema) HasColumn(pTableName TableName, pColumnName ColumnName) bool { +func (s DBSchema) HasColumn(pTableName string, pColumnName string) bool { lColumnTypes, lError := s.fetchColumnTypes(pTableName) if lError != nil { return false } return slices.ContainsFunc(lColumnTypes, func(ct *sql.ColumnType) bool { - return strings.EqualFold(ct.Name(), pColumnName.String()) + return strings.EqualFold(ct.Name(), pColumnName) }) } -func (s DBSchema) fetchColumnTypes(pTableName TableName) ([]*sql.ColumnType, error) { - lRows, lError := s.db.Query(fmt.Sprintf("SELECT * FROM %s WHERE 1=0", pTableName.String())) // NOSONAR go:S2077 - Dynamic SQL is controlled and sanitized internally +func (s DBSchema) fetchColumnTypes(pTableName string) ([]*sql.ColumnType, error) { + lRows, lError := s.db.Query(fmt.Sprintf("SELECT * FROM %s WHERE 1=0", pTableName)) // NOSONAR go:S2077 - Dynamic SQL is controlled and sanitized internally if lError != nil { return nil, lError } diff --git a/schema/table.go b/schema/table.go index 3f4eca6..1f46f4f 100644 --- a/schema/table.go +++ b/schema/table.go @@ -2,43 +2,19 @@ package schema import ( "fmt" - "regexp" "slices" "strings" ) -var regexValidTableName = regexp.MustCompile(`^([A-Za-z_][A-Za-z0-9_]*\.)*[A-Za-z_][A-Za-z0-9_]*$`) - -type TableName struct { - tableName string -} - -func NewTableName(pName string) (*TableName, error) { - if !regexValidTableName.MatchString(pName) { - return nil, fmt.Errorf("invalid table name: %q", pName) - } - return &TableName{pName}, nil -} - -func (tn TableName) String() string { - return tn.tableName -} - type Table struct { - name TableName + name string columns []Column } -func NewTable(pName string) (*Table, error) { - lTableName, lError := NewTableName(pName) - if lError != nil { - return nil, lError - } - return &Table{ - name: *lTableName, +func NewTable(pName string) Table { + return Table{ + name: pName, columns: []Column{}, - }, nil -} func (t *Table) AddColumn(pParams NewColumnParams) error { lColumn, lError := NewColumn(pParams) @@ -55,7 +31,7 @@ func (t *Table) AddColumn(pParams NewColumnParams) error { return nil } -func (t Table) Name() TableName { +func (t Table) Name() string { return t.name } From 5059922c8d8d40772b7747f2aab75a8ff1649b3f Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 11:53:44 -0300 Subject: [PATCH 02/26] replaced table's AddColumn by AddColumns --- schema/table.go | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/schema/table.go b/schema/table.go index 1f46f4f..b273c9b 100644 --- a/schema/table.go +++ b/schema/table.go @@ -15,20 +15,7 @@ func NewTable(pName string) Table { return Table{ name: pName, columns: []Column{}, - -func (t *Table) AddColumn(pParams NewColumnParams) error { - lColumn, lError := NewColumn(pParams) - if lError != nil { - return lError - } - lColumnAlreadyExists := slices.ContainsFunc(t.columns, func(c Column) bool { - return strings.EqualFold(lColumn.Name().String(), c.Name().String()) - }) - if lColumnAlreadyExists { - return fmt.Errorf("column %q already exists in table %q", lColumn.Name().String(), t.name) } - t.columns = append(t.columns, *lColumn) - return nil } func (t Table) Name() string { @@ -38,3 +25,17 @@ func (t Table) Name() string { func (t Table) Columns() []Column { return t.columns } + +func (t Table) AddColumns(pParams ...NewColumnParams) error { + for _, lColParams := range pParams { + lColumn := NewColumn(lColParams) + lColumnAlreadyExists := slices.ContainsFunc(t.columns, func(c Column) bool { + return strings.EqualFold(lColumn.Name(), c.Name()) + }) + if lColumnAlreadyExists { + return fmt.Errorf("column %q already exists in table %q", lColumn.Name(), t.Name()) + } + t.columns = append(t.columns, lColumn) + } + return nil +} From 51209b5670121aab319454a68c04057da34c9de7 Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 11:54:48 -0300 Subject: [PATCH 03/26] modified existing code to support previous changes --- builder.go | 6 +- internal/testutils/fake.go | 313 ++++++++++++------------------- internal/testutils/migrations.go | 90 +++------ migrations/migrations_test.go | 27 +-- migrations/migrator.go | 44 ++--- migrations/migrator_test.go | 16 +- schema/column_test.go | 8 +- schema/schema_test.go | 26 +-- schema/table_test.go | 30 +-- 9 files changed, 196 insertions(+), 364 deletions(-) diff --git a/builder.go b/builder.go index 0df2939..cbf4731 100644 --- a/builder.go +++ b/builder.go @@ -10,9 +10,9 @@ import ( // DDSQLBuilder creates DDL (Data Definition Language) SQL commands for defining schema in DBMS. type DDLSQLBuilder interface { CreateTable(pTable schema.Table) string - DropTable(pTableName schema.TableName) string - AlterTableAddColumn(pTableName schema.TableName, pColumn schema.Column) string - AlterTableDropColumn(pTableName schema.TableName, pColumnName schema.ColumnName) string + DropTable(pTableName string) string + AlterTableAddColumn(pTableName string, pColumn schema.Column) string + AlterTableDropColumn(pTableName string, pColumnName string) string ColumnTypeAsString(pColumnType schema.ColumnType) string } diff --git a/internal/testutils/fake.go b/internal/testutils/fake.go index 44ce3d7..f7967f5 100644 --- a/internal/testutils/fake.go +++ b/internal/testutils/fake.go @@ -7,214 +7,151 @@ import ( "github.com/ordershift/ormshift/schema" ) -func FakeProductAttributeTable(t *testing.T) *schema.Table { - lProductAttributeTable, lError := schema.NewTable("product_attribute") - if !AssertNotNilResultAndNilError(t, lProductAttributeTable, lError, "schema.NewTable") { - return nil - } - lError = lProductAttributeTable.AddColumn(schema.NewColumnParams{ - Name: "product_id", - Type: schema.Integer, - PrimaryKey: true, - NotNull: true, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "ProductAttributeTable.AddColumn") { - return nil - } - lError = lProductAttributeTable.AddColumn(schema.NewColumnParams{ - Name: "attribute_id", - Type: schema.Integer, - PrimaryKey: true, - NotNull: true, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "ProductAttributeTable.AddColumn") { - return nil - } - lError = lProductAttributeTable.AddColumn(schema.NewColumnParams{ - Name: "value", - Type: schema.Varchar, - Size: 75, - PrimaryKey: false, - NotNull: false, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "ProductAttributeTable.AddColumn") { - return nil - } - lError = lProductAttributeTable.AddColumn(schema.NewColumnParams{ - Name: "position", - Type: schema.Integer, - PrimaryKey: false, - NotNull: false, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "ProductAttributeTable.AddColumn") { - return nil +func FakeProductAttributeTable(t *testing.T) schema.Table { + lProductAttributeTable := schema.NewTable("product_attribute") + lError := lProductAttributeTable.AddColumns( + schema.NewColumnParams{ + Name: "product_id", + Type: schema.Integer, + PrimaryKey: true, + NotNull: true, + }, + schema.NewColumnParams{ + Name: "attribute_id", + Type: schema.Integer, + PrimaryKey: true, + NotNull: true, + }, + schema.NewColumnParams{ + Name: "value", + Type: schema.Varchar, + Size: 75, + }, + schema.NewColumnParams{ + Name: "position", + Type: schema.Integer, + }, + ) + if !AssertNilError(t, lError, "ProductAttributeTable.AddColumns") { + panic(lError) } return lProductAttributeTable } -func FakeUserTable(t *testing.T) *schema.Table { - lUserTable, lError := schema.NewTable("user") - if !AssertNotNilResultAndNilError(t, lUserTable, lError, "schema.NewTable") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "id", - Type: schema.Integer, - PrimaryKey: true, - NotNull: true, - AutoIncrement: true, - }) - if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "email", - Type: schema.Varchar, - Size: 80, - PrimaryKey: true, - NotNull: true, - AutoIncrement: false, - }) +func FakeUserTable(t *testing.T) schema.Table { + lUserTable := schema.NewTable("user") + lError := lUserTable.AddColumns( + schema.NewColumnParams{ + Name: "id", + Type: schema.Integer, + PrimaryKey: true, + NotNull: true, + AutoIncrement: true, + }, + schema.NewColumnParams{ + Name: "email", + Type: schema.Varchar, + Size: 80, + PrimaryKey: true, + NotNull: true, + AutoIncrement: false, + }, + schema.NewColumnParams{ + Name: "name", + Type: schema.Varchar, + Size: 50, + PrimaryKey: false, + NotNull: true, + AutoIncrement: false, + }, + schema.NewColumnParams{ + Name: "password_hash", + Type: schema.Varchar, + Size: 256, + PrimaryKey: false, + NotNull: false, + AutoIncrement: false, + }, + schema.NewColumnParams{ + Name: "active", + Type: schema.Boolean, + PrimaryKey: false, + NotNull: false, + AutoIncrement: false, + }, + schema.NewColumnParams{ + Name: "created_at", + Type: schema.DateTime, + PrimaryKey: false, + NotNull: false, + AutoIncrement: false, + }, + schema.NewColumnParams{ + Name: "user_master", + Type: schema.Integer, + PrimaryKey: false, + NotNull: false, + AutoIncrement: false, + }, + schema.NewColumnParams{ + Name: "master_user_id", + Type: schema.Integer, + PrimaryKey: false, + NotNull: false, + AutoIncrement: false, + }, + schema.NewColumnParams{ + Name: "licence_price", + Type: schema.Monetary, + PrimaryKey: false, + NotNull: false, + AutoIncrement: false, + }, + schema.NewColumnParams{ + Name: "relevance", + Type: schema.Decimal, + PrimaryKey: false, + NotNull: false, + AutoIncrement: false, + }, + schema.NewColumnParams{ + Name: "photo", + Type: schema.Binary, + PrimaryKey: false, + NotNull: false, + AutoIncrement: false, + }, + schema.NewColumnParams{ + Name: "any", + Type: schema.ColumnType(-1), + PrimaryKey: false, + NotNull: false, + AutoIncrement: false, + }, + ) if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "name", - Type: schema.Varchar, - Size: 50, - PrimaryKey: false, - NotNull: true, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "password_hash", - Type: schema.Varchar, - Size: 256, - PrimaryKey: false, - NotNull: false, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "active", - Type: schema.Boolean, - PrimaryKey: false, - NotNull: false, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "created_at", - Type: schema.DateTime, - PrimaryKey: false, - NotNull: false, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "user_master", - Type: schema.Integer, - PrimaryKey: false, - NotNull: false, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "master_user_id", - Type: schema.Integer, - PrimaryKey: false, - NotNull: false, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "licence_price", - Type: schema.Monetary, - PrimaryKey: false, - NotNull: false, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "relevance", - Type: schema.Decimal, - PrimaryKey: false, - NotNull: false, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "photo", - Type: schema.Binary, - PrimaryKey: false, - NotNull: false, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "any", - Type: schema.ColumnType(-1), - PrimaryKey: false, - NotNull: false, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil + panic(lError) } return lUserTable } -func FakeUserTableName(t *testing.T) *schema.TableName { - lUserTableName, lError := schema.NewTableName("user") - if !AssertNotNilResultAndNilError(t, lUserTableName, lError, "schema.NewTableName") { - return nil - } - return lUserTableName +func FakeUserTableName(t *testing.T) string { + return "user" } -func FakeUpdatedAtColumn(t *testing.T) *schema.Column { - lUpdatedAtColumn, lError := schema.NewColumn(schema.NewColumnParams{ +func FakeUpdatedAtColumn(t *testing.T) schema.Column { + lUpdatedAtColumn := schema.NewColumn(schema.NewColumnParams{ Name: "updated_at", Type: schema.DateTime, PrimaryKey: false, NotNull: false, AutoIncrement: false, }) - if !AssertNotNilResultAndNilError(t, lUpdatedAtColumn, lError, "schema.NewColumn") { - return nil - } return lUpdatedAtColumn } -func FakeUpdatedAtColumnName(t *testing.T) *schema.ColumnName { - lUpdatedAtColumnName, lError := schema.NewColumnName("updated_at") - if !AssertNotNilResultAndNilError(t, lUpdatedAtColumnName, lError, "schema.NewColumnName") { - return nil - } - return lUpdatedAtColumnName +func FakeUpdatedAtColumnName(t *testing.T) string { + return "updated_at" } func FakeInteroperateSQLCommandWithNamedArgsFunc(command string, namedArgs ...sql.NamedArg) (string, []any) { diff --git a/internal/testutils/migrations.go b/internal/testutils/migrations.go index 8382f9d..c8b1b0d 100644 --- a/internal/testutils/migrations.go +++ b/internal/testutils/migrations.go @@ -11,131 +11,97 @@ import ( type M001_Create_Table_User struct{} func (m M001_Create_Table_User) Up(pMigrator *migrations.Migrator) error { - lUserTable, lError := schema.NewTable("user") - if lError != nil { - return lError - } + lUserTable := schema.NewTable("user") if pMigrator.Database().DBSchema().HasTable(lUserTable.Name()) { return nil } - columns := []schema.NewColumnParams{ - { + + lError := lUserTable.AddColumns( + schema.NewColumnParams{ Name: "id", Type: schema.Integer, AutoIncrement: true, PrimaryKey: true, NotNull: true, }, - { + schema.NewColumnParams{ Name: "name", Type: schema.Varchar, Size: 50, PrimaryKey: false, NotNull: false, }, - { + schema.NewColumnParams{ Name: "email", Type: schema.Varchar, Size: 120, PrimaryKey: false, NotNull: false, }, - { + schema.NewColumnParams{ Name: "active", Type: schema.Boolean, PrimaryKey: false, NotNull: false, }, - { + schema.NewColumnParams{ Name: "ammount", Type: schema.Monetary, PrimaryKey: false, NotNull: false, }, - { + schema.NewColumnParams{ Name: "percent", Type: schema.Decimal, PrimaryKey: false, NotNull: false, }, - { + schema.NewColumnParams{ Name: "photo", Type: schema.Binary, PrimaryKey: false, NotNull: false, }, - } - - for _, col := range columns { - if err := lUserTable.AddColumn(col); err != nil { - return err - } - } + ) - _, lError = pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().CreateTable(*lUserTable)) - if lError != nil { - return lError - } - return nil + _, lError = pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().CreateTable(lUserTable)) + return lError } func (m M001_Create_Table_User) Down(pMigrator *migrations.Migrator) error { - lUserTableName, lError := schema.NewTableName("user") - if lError != nil { - return lError - } - if !pMigrator.Database().DBSchema().HasTable(*lUserTableName) { + lUserTableName := "user" + if !pMigrator.Database().DBSchema().HasTable(lUserTableName) { return nil } - _, lError = pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().DropTable(*lUserTableName)) - if lError != nil { - return lError - } - return nil + + _, lError := pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().DropTable(lUserTableName)) + return lError } // M002_Alter_Table_User_Add_Column_UpdatedAt adds the "updated_at" column to the "user" table. type M002_Alter_Table_User_Add_Column_UpdatedAt struct{} func (m M002_Alter_Table_User_Add_Column_UpdatedAt) Up(pMigrator *migrations.Migrator) error { - lUserTableName, lError := schema.NewTableName("user") - if lError != nil { - return lError - } - lUpdatedAtColumn, lError := schema.NewColumn(schema.NewColumnParams{ + lUserTableName := "user" + lUpdatedAtColumn := schema.NewColumn(schema.NewColumnParams{ Name: "updated_at", Type: schema.DateTime, }) - if lError != nil { - return lError - } - if pMigrator.Database().DBSchema().HasColumn(*lUserTableName, lUpdatedAtColumn.Name()) { + if pMigrator.Database().DBSchema().HasColumn(lUserTableName, lUpdatedAtColumn.Name()) { return nil } - _, lError = pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().AlterTableAddColumn(*lUserTableName, *lUpdatedAtColumn)) - if lError != nil { - return lError - } - return nil + _, lError := pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().AlterTableAddColumn(lUserTableName, lUpdatedAtColumn)) + return lError } func (m M002_Alter_Table_User_Add_Column_UpdatedAt) Down(pMigrator *migrations.Migrator) error { - lUserTableName, lError := schema.NewTableName("user") - if lError != nil { - return lError - } - lUpdatedAtColumnName, lError := schema.NewColumnName("updated_at") - if lError != nil { - return lError - } - if !pMigrator.Database().DBSchema().HasColumn(*lUserTableName, *lUpdatedAtColumnName) { + lUserTableName := "user" + lUpdatedAtColumnName := "updated_at" + if !pMigrator.Database().DBSchema().HasColumn(lUserTableName, lUpdatedAtColumnName) { return nil } - _, lError = pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().AlterTableDropColumn(*lUserTableName, *lUpdatedAtColumnName)) - if lError != nil { - return lError - } - return nil + _, lError := pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().AlterTableDropColumn(lUserTableName, lUpdatedAtColumnName)) + return lError } // M003_Bad_Migration_Fails_To_Apply is a migration that always fails to apply. diff --git a/migrations/migrations_test.go b/migrations/migrations_test.go index f19f974..71c91e2 100644 --- a/migrations/migrations_test.go +++ b/migrations/migrations_test.go @@ -7,7 +7,6 @@ import ( "github.com/ordershift/ormshift/dialects/sqlite" "github.com/ordershift/ormshift/internal/testutils" "github.com/ordershift/ormshift/migrations" - "github.com/ordershift/ormshift/schema" ) func TestMigrate(t *testing.T) { @@ -27,16 +26,10 @@ func TestMigrate(t *testing.T) { if !testutils.AssertNotNilResultAndNilError(t, lMigrator, lError, "migrations.Migrate") { return } - lUserTableName, lError := schema.NewTableName("user") - if !testutils.AssertNilError(t, lError, "migrations.NewTableName") { - return - } - lUpdatedAtColumnName, lError := schema.NewColumnName("updated_at") - if !testutils.AssertNilError(t, lError, "migrations.NewColumnName") { - return - } - testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().HasColumn(*lUserTableName, *lUpdatedAtColumnName), "Migrator.DBSchema.HasColumn[user.updated_at]") - testutils.AssertEqualWithLabel(t, 2, len(lMigrator.AppliedMigrations()), "len(Migrator.AppliedMigrationNames)") + lUserTableName := "user" + lUpdatedAtColumnName := "updated_at" + testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().HasColumn(lUserTableName, lUpdatedAtColumnName), "Migrator.DBSchema.HasColumn[user.updated_at]") + testutils.AssertEqualWithLabel(t, 2, len(lMigrator.AppliedMigrations()), "len(Migrator.AppliedMigrations)") } func TestMigrateTwice(t *testing.T) { @@ -67,15 +60,9 @@ func TestMigrateTwice(t *testing.T) { return } - lUserTableName, lError := schema.NewTableName("user") - if !testutils.AssertNilError(t, lError, "migrations.NewTableName") { - return - } - lUpdatedAtColumnName, lError := schema.NewColumnName("updated_at") - if !testutils.AssertNilError(t, lError, "migrations.NewColumnName") { - return - } - testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().HasColumn(*lUserTableName, *lUpdatedAtColumnName), "Migrator.DBSchema.HasColumn[user.updated_at]") + lUserTableName := "user" + lUpdatedAtColumnName := "updated_at" + testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().HasColumn(lUserTableName, lUpdatedAtColumnName), "Migrator.DBSchema.HasColumn[user.updated_at]") testutils.AssertEqualWithLabel(t, 2, len(lMigrator.AppliedMigrations()), "len(Migrator.AppliedMigrations)") } diff --git a/migrations/migrator.go b/migrations/migrator.go index 0972f7d..b6cf48c 100644 --- a/migrations/migrator.go +++ b/migrations/migrator.go @@ -165,33 +165,25 @@ func getAppliedMigrationNames(pDatabase ormshift.Database, pConfig MigratorConfi } func ensureMigrationsTableExists(pDatabase ormshift.Database, pConfig MigratorConfig) error { - lMigrationsTable, lError := schema.NewTable(pConfig.tableName) - if lError != nil { - return lError + lMigrationsTable := schema.NewTable(pConfig.tableName) + if pDatabase.DBSchema().HasTable(lMigrationsTable.Name()) { + return nil } - if !pDatabase.DBSchema().HasTable(lMigrationsTable.Name()) { - columns := []schema.NewColumnParams{ - { - Name: pConfig.migrationNameColumn, - Type: schema.Varchar, - Size: pConfig.migrationNameMaxLength, - PrimaryKey: true, - NotNull: true, - }, - { - Name: pConfig.appliedAtColumn, - Type: schema.DateTime, - NotNull: true, - }, - } - - for _, col := range columns { - if err := lMigrationsTable.AddColumn(col); err != nil { - return err - } - } + lMigrationsTable.AddColumns( + schema.NewColumnParams{ + Name: pConfig.migrationNameColumn, + Type: schema.Varchar, + Size: pConfig.migrationNameMaxLength, + PrimaryKey: true, + NotNull: true, + }, + schema.NewColumnParams{ + Name: pConfig.appliedAtColumn, + Type: schema.DateTime, + NotNull: true, + }, + ) - _, lError = pDatabase.SQLExecutor().Exec(pDatabase.SQLBuilder().CreateTable(*lMigrationsTable)) // NOSONAR go:S2077 - Dynamic SQL is controlled and sanitized internally - } + _, lError := pDatabase.SQLExecutor().Exec(pDatabase.SQLBuilder().CreateTable(lMigrationsTable)) // NOSONAR go:S2077 - Dynamic SQL is controlled and sanitized internally return lError } diff --git a/migrations/migrator_test.go b/migrations/migrator_test.go index f52d5bc..68589e2 100644 --- a/migrations/migrator_test.go +++ b/migrations/migrator_test.go @@ -8,7 +8,6 @@ import ( "github.com/ordershift/ormshift/dialects/sqlite" "github.com/ordershift/ormshift/internal/testutils" "github.com/ordershift/ormshift/migrations" - "github.com/ordershift/ormshift/schema" ) func TestNewMigratorWhenDatabaseIsNil(t *testing.T) { @@ -70,21 +69,16 @@ func TestRevertLastAppliedMigration(t *testing.T) { return } - lUserTableName, lError := schema.NewTableName("user") - if !testutils.AssertNilError(t, lError, "migrations.NewTableName") { - return - } - testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().HasTable(*lUserTableName), "Migrator.DBSchema.HasTable[user]") + lUserTableName := "user" + testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().HasTable(lUserTableName), "Migrator.DBSchema.HasTable[user]") lError = lMigrator.RevertLastAppliedMigration() if !testutils.AssertNilError(t, lError, "Migrator.RevertLastAppliedMigration") { return } - lUpdatedAtColumnName, lError := schema.NewColumnName("updated_at") - if !testutils.AssertNilError(t, lError, "migrations.NewColumnName") { - return - } - testutils.AssertEqualWithLabel(t, false, lDB.DBSchema().HasColumn(*lUserTableName, *lUpdatedAtColumnName), "Migrator.DBSchema.HasColumn[user.updated_at]") + + lUpdatedAtColumnName := "updated_at" + testutils.AssertEqualWithLabel(t, false, lDB.DBSchema().HasColumn(lUserTableName, lUpdatedAtColumnName), "Migrator.DBSchema.HasColumn[user.updated_at]") } func TestRevertLastAppliedMigrationWhenNoMigrationsApplied(t *testing.T) { diff --git a/schema/column_test.go b/schema/column_test.go index f5e5284..0a634a0 100644 --- a/schema/column_test.go +++ b/schema/column_test.go @@ -8,11 +8,9 @@ import ( ) func TestColumn(t *testing.T) { - lColumn, lError := schema.NewColumn(schema.NewColumnParams{Name: "id", Type: schema.Integer, NotNull: true, PrimaryKey: true, AutoIncrement: true}) - if !testutils.AssertNilError(t, lError, "schema.NewColumn") { - return - } - testutils.AssertEqualWithLabel(t, "id", lColumn.Name().String(), "Column.Name") + lColumn := schema.NewColumn(schema.NewColumnParams{Name: "id", Type: schema.Integer, NotNull: true, PrimaryKey: true, AutoIncrement: true}) + + testutils.AssertEqualWithLabel(t, "id", lColumn.Name(), "Column.Name") testutils.AssertEqualWithLabel(t, schema.Integer, lColumn.Type(), "Column.Type") testutils.AssertEqualWithLabel(t, uint(0), lColumn.Size(), "Column.Size") testutils.AssertEqualWithLabel(t, true, lColumn.PrimaryKey(), "Column.IsPrimaryKey") diff --git a/schema/schema_test.go b/schema/schema_test.go index 2a30e66..df33b38 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -39,11 +39,8 @@ func TestHasColumn(t *testing.T) { defer func() { _ = lDB.Close() }() lProductAttributeTable := testutils.FakeProductAttributeTable(t) - if lProductAttributeTable == nil { - return - } - _, lError = lDB.SQLExecutor().Exec(sqlite.Driver().SQLBuilder().CreateTable(*lProductAttributeTable)) + _, lError = lDB.SQLExecutor().Exec(sqlite.Driver().SQLBuilder().CreateTable(lProductAttributeTable)) if !testutils.AssertNilError(t, lError, "DB.Exec") { return } @@ -53,17 +50,10 @@ func TestHasColumn(t *testing.T) { for _, lColumn := range lProductAttributeTable.Columns() { testutils.AssertEqualWithLabel(t, true, lDBSchema.HasColumn(lProductAttributeTable.Name(), lColumn.Name()), "DBSchema.HasColumn") } - lAnyTableName, lError := schema.NewTableName("any_table") - if !testutils.AssertNotNilResultAndNilError(t, lAnyTableName, lError, "ormshift.NewTableName") { - return - } - testutils.AssertEqualWithLabel(t, false, lDBSchema.HasTable(*lAnyTableName), "DBSchema.HasTable") - lAnyColumnName, lError := schema.NewColumnName("any_col") - if !testutils.AssertNotNilResultAndNilError(t, lAnyColumnName, lError, "ormshift.NewTableName") { - return - } - testutils.AssertEqualWithLabel(t, false, lDBSchema.HasColumn(lProductAttributeTable.Name(), *lAnyColumnName), "DBSchema.HasColumn") - testutils.AssertEqualWithLabel(t, false, lDBSchema.HasColumn(*lAnyTableName, *lAnyColumnName), "DBSchema.HasColumn") + lAnyTableName := "any_table" + lAnyColumnName := "any_col" + testutils.AssertEqualWithLabel(t, false, lDBSchema.HasColumn(lProductAttributeTable.Name(), lAnyColumnName), "DBSchema.HasColumn") + testutils.AssertEqualWithLabel(t, false, lDBSchema.HasColumn(lAnyTableName, lAnyColumnName), "DBSchema.HasColumn") } func TestHasTableReturnsFalseWhenDatabaseIsInvalid(t *testing.T) { @@ -75,12 +65,8 @@ func TestHasTableReturnsFalseWhenDatabaseIsInvalid(t *testing.T) { defer func() { _ = lDB.Close() }() lProductAttributeTable := testutils.FakeProductAttributeTable(t) - if lProductAttributeTable == nil { - _ = lDB.Close() - return - } - _, lError = lDB.SQLExecutor().Exec(sqlite.Driver().SQLBuilder().CreateTable(*lProductAttributeTable)) + _, lError = lDB.SQLExecutor().Exec(sqlite.Driver().SQLBuilder().CreateTable(lProductAttributeTable)) if !testutils.AssertNilError(t, lError, "DB.Exec") { _ = lDB.Close() return diff --git a/schema/table_test.go b/schema/table_test.go index 7ed323a..51e650a 100644 --- a/schema/table_test.go +++ b/schema/table_test.go @@ -8,37 +8,9 @@ import ( "github.com/ordershift/ormshift/schema" ) -func TestNewTableFailsWithInvalidName(t *testing.T) { - lInvalidTableName := "123456-table" - lTable, lError := schema.NewTable(lInvalidTableName) - if !testutils.AssertNilResultAndNotNilError(t, lTable, lError, "schema.NewTable") { - return - } - testutils.AssertErrorMessage(t, fmt.Sprintf("invalid table name: %q", lInvalidTableName), lError, "schema.NewTable") -} - -func TestAddColumnFailsWithInvalidName(t *testing.T) { - lProductAttributeTable := testutils.FakeProductAttributeTable(t) - if lProductAttributeTable == nil { - return - } - lInvalidColumnName := "123456-column" - lError := lProductAttributeTable.AddColumn(schema.NewColumnParams{ - Name: lInvalidColumnName, - Type: schema.Integer, - }) - if !testutils.AssertNotNilError(t, lError, "Table.AddColumn") { - return - } - testutils.AssertErrorMessage(t, fmt.Sprintf("invalid column name: %q", lInvalidColumnName), lError, "Table.AddColumn") -} - func TestAddColumnFailsWhenAlreadyExists(t *testing.T) { lProductAttributeTable := testutils.FakeProductAttributeTable(t) - if lProductAttributeTable == nil { - return - } - lError := lProductAttributeTable.AddColumn(schema.NewColumnParams{ + lError := lProductAttributeTable.AddColumns(schema.NewColumnParams{ Name: "value", Type: schema.Integer, }) From 7f5539815755ee02552bbc3815f589e3e29954ef Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 11:55:50 -0300 Subject: [PATCH 04/26] added quoting mechanism to identifiers (tables/columns) when building queries to prevent sql injection --- internal/builder_generic.go | 43 ++++++++++++++++++--------- internal/builder_generic_test.go | 50 +++++++++++++++----------------- 2 files changed, 52 insertions(+), 41 deletions(-) diff --git a/internal/builder_generic.go b/internal/builder_generic.go index 858428d..88b1f5f 100644 --- a/internal/builder_generic.go +++ b/internal/builder_generic.go @@ -11,20 +11,35 @@ import ( type ColumnDefinitionFunc func(schema.Column) string +type QuoteIdentifierFunc func(string) string + type InteroperateSQLCommandWithNamedArgsFunc func(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) type genericSQLBuilder struct { ColumnDefinitionFunc ColumnDefinitionFunc InteroperateSQLCommandWithNamedArgsFunc InteroperateSQLCommandWithNamedArgsFunc + QuoteIdentifierFunc QuoteIdentifierFunc } -func NewGenericSQLBuilder(pColumnDefinitionFunc ColumnDefinitionFunc, pInteroperateSQLCommandWithNamedArgsFunc InteroperateSQLCommandWithNamedArgsFunc) ormshift.SQLBuilder { +func NewGenericSQLBuilder( + pColumnDefinitionFunc ColumnDefinitionFunc, + pQuoteIdentifierFunc QuoteIdentifierFunc, + pInteroperateSQLCommandWithNamedArgsFunc InteroperateSQLCommandWithNamedArgsFunc, +) ormshift.SQLBuilder { return genericSQLBuilder{ ColumnDefinitionFunc: pColumnDefinitionFunc, + QuoteIdentifierFunc: pQuoteIdentifierFunc, InteroperateSQLCommandWithNamedArgsFunc: pInteroperateSQLCommandWithNamedArgsFunc, } } +func (sb genericSQLBuilder) quoteIdentifier(pIdentifier string) string { + if sb.QuoteIdentifierFunc != nil { + return sb.QuoteIdentifierFunc(pIdentifier) + } + return pIdentifier +} + func (sb genericSQLBuilder) CreateTable(pTable schema.Table) string { lColumns := "" lPKColumns := "" @@ -38,7 +53,7 @@ func (sb genericSQLBuilder) CreateTable(pTable schema.Table) string { if lPKColumns != "" { lPKColumns += "," } - lPKColumns += lColumn.Name().String() + lPKColumns += sb.quoteIdentifier(lColumn.Name()) } } @@ -48,19 +63,19 @@ func (sb genericSQLBuilder) CreateTable(pTable schema.Table) string { } lColumns += fmt.Sprintf("PRIMARY KEY (%s)", lPKColumns) } - return fmt.Sprintf("CREATE TABLE %s (%s);", pTable.Name().String(), lColumns) + return fmt.Sprintf("CREATE TABLE %s (%s);", sb.quoteIdentifier(pTable.Name()), lColumns) } -func (sb genericSQLBuilder) DropTable(pTableName schema.TableName) string { - return fmt.Sprintf("DROP TABLE %s;", pTableName.String()) +func (sb genericSQLBuilder) DropTable(pTableName string) string { + return fmt.Sprintf("DROP TABLE %s;", sb.quoteIdentifier(pTableName)) } -func (sb genericSQLBuilder) AlterTableAddColumn(pTableName schema.TableName, pColumn schema.Column) string { - return fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s;", pTableName.String(), sb.columnDefinition(pColumn)) +func (sb genericSQLBuilder) AlterTableAddColumn(pTableName string, pColumn schema.Column) string { + return fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s;", sb.quoteIdentifier(pTableName), sb.columnDefinition(pColumn)) } -func (sb genericSQLBuilder) AlterTableDropColumn(pTableName schema.TableName, pColumnName schema.ColumnName) string { - return fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s;", pTableName.String(), pColumnName.String()) +func (sb genericSQLBuilder) AlterTableDropColumn(pTableName string, pColumnName string) string { + return fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s;", sb.quoteIdentifier(pTableName), sb.quoteIdentifier(pColumnName)) } func (sb genericSQLBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) string { @@ -72,11 +87,11 @@ func (sb genericSQLBuilder) columnDefinition(pColumn schema.Column) string { if sb.ColumnDefinitionFunc != nil { return sb.ColumnDefinitionFunc(pColumn) } - return fmt.Sprintf("%s %s", pColumn.Name().String(), sb.ColumnTypeAsString(pColumn.Type())) + return fmt.Sprintf("%s %s", sb.quoteIdentifier(pColumn.Name()), sb.ColumnTypeAsString(pColumn.Type())) } func (sb genericSQLBuilder) Insert(pTableName string, pColumns []string) string { - return fmt.Sprintf("insert into %s (%s) values (%s)", pTableName, sb.columnsList(pColumns), sb.namesList(pColumns)) + return fmt.Sprintf("insert into %s (%s) values (%s)", sb.quoteIdentifier(pTableName), sb.columnsList(pColumns), sb.namesList(pColumns)) } func (sb genericSQLBuilder) InsertWithValues(pTableName string, pColumnsValues ormshift.ColumnsValues) (string, []any) { @@ -86,7 +101,7 @@ func (sb genericSQLBuilder) InsertWithValues(pTableName string, pColumnsValues o } func (sb genericSQLBuilder) Update(pTableName string, pColumns, pColumnsWhere []string) string { - lUpdate := fmt.Sprintf("update %s set %s ", pTableName, sb.columnEqualNameList(pColumns, ",")) + lUpdate := fmt.Sprintf("update %s set %s ", sb.quoteIdentifier(pTableName), sb.columnEqualNameList(pColumns, ",")) if len(pColumnsWhere) > 0 { lUpdate += fmt.Sprintf("where %s", sb.columnEqualNameList(pColumnsWhere, " and ")) // NOSONAR go:S1192 - duplicate tradeoff accepted } @@ -100,7 +115,7 @@ func (sb genericSQLBuilder) UpdateWithValues(pTableName string, pColumns, pColum } func (sb genericSQLBuilder) Delete(pTableName string, pColumnsWhere []string) string { - lDelete := fmt.Sprintf("delete from %s ", pTableName) + lDelete := fmt.Sprintf("delete from %s ", sb.quoteIdentifier(pTableName)) if len(pColumnsWhere) > 0 { lDelete += fmt.Sprintf("where %s", sb.columnEqualNameList(pColumnsWhere, " and ")) // NOSONAR go:S1192 - duplicate tradeoff accepted } @@ -114,7 +129,7 @@ func (sb genericSQLBuilder) DeleteWithValues(pTableName string, pWhereColumnsVal } func (sb genericSQLBuilder) Select(pTableName string, pColumns, pColumnsWhere []string) string { - lUpdate := fmt.Sprintf("select %s from %s ", sb.columnsList(pColumns), pTableName) + lUpdate := fmt.Sprintf("select %s from %s ", sb.columnsList(pColumns), sb.quoteIdentifier(pTableName)) if len(pColumnsWhere) > 0 { lUpdate += fmt.Sprintf("where %s", sb.columnEqualNameList(pColumnsWhere, " and ")) // NOSONAR go:S1192 - duplicate tradeoff accepted } diff --git a/internal/builder_generic_test.go b/internal/builder_generic_test.go index 2204821..50589e2 100644 --- a/internal/builder_generic_test.go +++ b/internal/builder_generic_test.go @@ -11,52 +11,52 @@ import ( ) func TestCreateTable(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lUserTable := testutils.FakeUserTable(t) lExpectedSQL := "CREATE TABLE user (id <>,email <>,name <>,password_hash <>," + "active <>,created_at <>,user_master <>,master_user_id <>," + "licence_price <>,relevance <>,photo <>,any <>,PRIMARY KEY (id,email));" - lReturnedSQL := lSQLBuilder.CreateTable(*lUserTable) + lReturnedSQL := lSQLBuilder.CreateTable(lUserTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") lProductAttributeTable := testutils.FakeProductAttributeTable(t) lExpectedSQL = "CREATE TABLE product_attribute (product_id <>,attribute_id <>,value <>,position <>,PRIMARY KEY (product_id,attribute_id));" - lReturnedSQL = lSQLBuilder.CreateTable(*lProductAttributeTable) + lReturnedSQL = lSQLBuilder.CreateTable(lProductAttributeTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") } func TestDropTable(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lUserTableName := testutils.FakeUserTableName(t) lExpectedSQL := "DROP TABLE user;" - lReturnedSQL := lSQLBuilder.DropTable(*lUserTableName) + lReturnedSQL := lSQLBuilder.DropTable(lUserTableName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.DropTable") } func TestAlterTableAddColumn(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumn := testutils.FakeUpdatedAtColumn(t) lExpectedSQL := "ALTER TABLE user ADD COLUMN updated_at <>;" - lReturnedSQL := lSQLBuilder.AlterTableAddColumn(*lUserTableName, *lUpdatedAtColumn) + lReturnedSQL := lSQLBuilder.AlterTableAddColumn(lUserTableName, lUpdatedAtColumn) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableAddColumn") } func TestAlterTableDropColumn(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumnName := testutils.FakeUpdatedAtColumnName(t) lExpectedSQL := "ALTER TABLE user DROP COLUMN updated_at;" - lReturnedSQL := lSQLBuilder.AlterTableDropColumn(*lUserTableName, *lUpdatedAtColumnName) + lReturnedSQL := lSQLBuilder.AlterTableDropColumn(lUserTableName, lUpdatedAtColumnName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableDropColumn") } func TestInsert(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL := lSQLBuilder.Insert("product", []string{"id", "sku", "name", "description"}) lExpectedSQL := "insert into product (id,sku,name,description) values (@id,@sku,@name,@description)" @@ -64,7 +64,7 @@ func TestInsert(t *testing.T) { } func TestInsertWithValues(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL, lReturnedValues := lSQLBuilder.InsertWithValues("product", ormshift.ColumnsValues{"id": 1, "sku": "1.005.12.9", "name": "Trufa Sabor Amarula 30g Cacaushow"}) lExpectedSQL := "insert into product (id,name,sku) values (@id,@name,@sku)" @@ -76,7 +76,7 @@ func TestInsertWithValues(t *testing.T) { } func TestUpdate(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL := lSQLBuilder.Update("product", []string{"sku", "name", "description"}, []string{"id"}) lExpectedSQL := "update product set sku = @sku,name = @name,description = @description where id = @id" @@ -84,7 +84,7 @@ func TestUpdate(t *testing.T) { } func TestUpdateWithValues(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL, lReturnedValues := lSQLBuilder.UpdateWithValues("product", []string{"sku", "name"}, []string{"id"}, ormshift.ColumnsValues{"id": 1, "sku": "1.005.12.5", "name": "Trufa Sabor Amarula 18g Cacaushow"}) lExpectedSQL := "update product set sku = @sku,name = @name where id = @id" @@ -96,7 +96,7 @@ func TestUpdateWithValues(t *testing.T) { } func TestDelete(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL := lSQLBuilder.Delete("product", []string{"id"}) lExpectedSQL := "delete from product where id = @id" @@ -104,7 +104,7 @@ func TestDelete(t *testing.T) { } func TestDeleteWithValues(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL, lReturnedValues := lSQLBuilder.DeleteWithValues("product", ormshift.ColumnsValues{"id": 1}) lExpectedSQL := "delete from product where id = @id" @@ -114,7 +114,7 @@ func TestDeleteWithValues(t *testing.T) { } func TestSelect(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL := lSQLBuilder.Select("product", []string{"id", "name", "description"}, []string{"sku", "active"}) lExpectedSQL := "select id,name,description from product where sku = @sku and active = @active" @@ -122,7 +122,7 @@ func TestSelect(t *testing.T) { } func TestSelectWithValues(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL, lReturnedValues := lSQLBuilder.SelectWithValues("product", []string{"id", "sku", "name", "description"}, ormshift.ColumnsValues{"category_id": 1, "active": true}) lExpectedSQL := "select id,sku,name,description from product where active = @active and category_id = @category_id" @@ -133,7 +133,7 @@ func TestSelectWithValues(t *testing.T) { } func TestSelectWithPagination(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL := lSQLBuilder.SelectWithPagination("select * from product", 10, 5) lExpectedSQL := "select * from product LIMIT 10 OFFSET 40" @@ -141,7 +141,7 @@ func TestSelectWithPagination(t *testing.T) { } func TestInteroperateSQLCommandWithNamedArgs(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, testutils.FakeInteroperateSQLCommandWithNamedArgsFunc) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, testutils.FakeInteroperateSQLCommandWithNamedArgsFunc) lReturnedSQL, lReturnedNamedArgs := lSQLBuilder.InteroperateSQLCommandWithNamedArgs("original command", sql.NamedArg{Name: "param1", Value: 1}) lExpectedSQL := "command has been modified" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.InteroperateSQLCommandWithNamedArgs.SQL") @@ -150,13 +150,9 @@ func TestInteroperateSQLCommandWithNamedArgs(t *testing.T) { } func TestColumnDefinition(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(testutils.FakeColumnDefinitionFunc, nil) - lColumn, lError := schema.NewColumn(schema.NewColumnParams{Name: "column_name", Type: schema.Integer, Size: 0}) - testutils.AssertNilError(t, lError, "schema.NewColumn") - - lTableName, lError := schema.NewTableName("test_table") - testutils.AssertNilError(t, lError, "schema.NewTableName") - - lReturnedSQL := lSQLBuilder.AlterTableAddColumn(*lTableName, *lColumn) + lSQLBuilder := internal.NewGenericSQLBuilder(testutils.FakeColumnDefinitionFunc, nil, nil) + lColumn := schema.NewColumn(schema.NewColumnParams{Name: "column_name", Type: schema.Integer, Size: 0}) + lTableName := "test_table" + lReturnedSQL := lSQLBuilder.AlterTableAddColumn(lTableName, lColumn) testutils.AssertEqualWithLabel(t, "ALTER TABLE test_table ADD COLUMN fake;", lReturnedSQL, "SQLBuilder.ColumnDefinition") } From 674146d6bdf44f8b3749a4d7c30067eee9389210 Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 13:28:57 -0300 Subject: [PATCH 05/26] implemented quoting identifiers on each dialect --- dialects/postgresql/builder.go | 22 +++++++++++++++------- dialects/sqlite/builder.go | 29 +++++++++++++++++++---------- dialects/sqlserver/builder.go | 30 ++++++++++++++++++++---------- 3 files changed, 54 insertions(+), 27 deletions(-) diff --git a/dialects/postgresql/builder.go b/dialects/postgresql/builder.go index f8c1826..7e22e10 100644 --- a/dialects/postgresql/builder.go +++ b/dialects/postgresql/builder.go @@ -16,24 +16,24 @@ type postgresqlBuilder struct { } func newPostgreSQLBuilder() ormshift.SQLBuilder { - lBuilder := postgresqlBuilder{} - lBuilder.generic = internal.NewGenericSQLBuilder(lBuilder.columnDefinition, lBuilder.InteroperateSQLCommandWithNamedArgs) - return lBuilder + sb := postgresqlBuilder{} + sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, sb.quoteIdentifier, sb.InteroperateSQLCommandWithNamedArgs) + return sb } func (sb postgresqlBuilder) CreateTable(pTable schema.Table) string { return sb.generic.CreateTable(pTable) } -func (sb postgresqlBuilder) DropTable(pTableName schema.TableName) string { +func (sb postgresqlBuilder) DropTable(pTableName string) string { return sb.generic.DropTable(pTableName) } -func (sb postgresqlBuilder) AlterTableAddColumn(pTableName schema.TableName, pColumn schema.Column) string { +func (sb postgresqlBuilder) AlterTableAddColumn(pTableName string, pColumn schema.Column) string { return sb.generic.AlterTableAddColumn(pTableName, pColumn) } -func (sb postgresqlBuilder) AlterTableDropColumn(pTableName schema.TableName, pColumnName schema.ColumnName) string { +func (sb postgresqlBuilder) AlterTableDropColumn(pTableName string, pColumnName string) string { return sb.generic.AlterTableDropColumn(pTableName, pColumnName) } @@ -59,7 +59,7 @@ func (sb postgresqlBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) st } func (sb postgresqlBuilder) columnDefinition(pColumn schema.Column) string { - lColumnDef := pColumn.Name().String() + lColumnDef := sb.quoteIdentifier(pColumn.Name()) if pColumn.AutoIncrement() { lColumnDef += " BIGSERIAL" } else { @@ -111,6 +111,14 @@ func (sb postgresqlBuilder) SelectWithPagination(pSQLSelectCommand string, pRows return sb.generic.SelectWithPagination(pSQLSelectCommand, pRowsPerPage, pPageNumber) } +func (sb postgresqlBuilder) quoteIdentifier(pIdentifier string) string { + // PostgreSQL uses double quotes: "identifier" + // Escape rule: double quote becomes two double quotes + // Example: users -> "users", table"name -> "table""name" + pIdentifier = strings.ReplaceAll(pIdentifier, `"`, `""`) + return fmt.Sprintf(`"%s"`, pIdentifier) +} + func (sb postgresqlBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { lSQLCommand := pSQLCommand lArgs := []any{} diff --git a/dialects/sqlite/builder.go b/dialects/sqlite/builder.go index 07cc112..1633bad 100644 --- a/dialects/sqlite/builder.go +++ b/dialects/sqlite/builder.go @@ -3,6 +3,7 @@ package sqlite import ( "database/sql" "fmt" + "strings" "github.com/ordershift/ormshift" "github.com/ordershift/ormshift/internal" @@ -14,9 +15,9 @@ type sqliteBuilder struct { } func newSQLiteBuilder() ormshift.SQLBuilder { - lBuilder := sqliteBuilder{} - lBuilder.generic = internal.NewGenericSQLBuilder(lBuilder.columnDefinition, nil) - return lBuilder + sb := sqliteBuilder{} + sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, sb.quoteIdentifier, nil) + return sb } func (sb sqliteBuilder) CreateTable(pTable schema.Table) string { @@ -33,7 +34,7 @@ func (sb sqliteBuilder) CreateTable(pTable schema.Table) string { if lPKColumns != "" { lPKColumns += "," } - lPKColumns += lColumn.Name().String() + lPKColumns += lColumn.Name() } if !lHasAutoIncrementColumn { @@ -45,20 +46,21 @@ func (sb sqliteBuilder) CreateTable(pTable schema.Table) string { if lColumns != "" { lColumns += "," } - lColumns += fmt.Sprintf("CONSTRAINT PK_%s PRIMARY KEY (%s)", pTable.Name().String(), lPKColumns) + lPKConstraintName := sb.quoteIdentifier("PK_" + pTable.Name()) + lColumns += fmt.Sprintf("CONSTRAINT %s PRIMARY KEY (%s)", lPKConstraintName, lPKColumns) } - return fmt.Sprintf("CREATE TABLE %s (%s);", pTable.Name().String(), lColumns) + return fmt.Sprintf("CREATE TABLE %s (%s);", sb.quoteIdentifier(pTable.Name()), lColumns) } -func (sb sqliteBuilder) DropTable(pTableName schema.TableName) string { +func (sb sqliteBuilder) DropTable(pTableName string) string { return sb.generic.DropTable(pTableName) } -func (sb sqliteBuilder) AlterTableAddColumn(pTableName schema.TableName, pColumn schema.Column) string { +func (sb sqliteBuilder) AlterTableAddColumn(pTableName string, pColumn schema.Column) string { return sb.generic.AlterTableAddColumn(pTableName, pColumn) } -func (sb sqliteBuilder) AlterTableDropColumn(pTableName schema.TableName, pColumnName schema.ColumnName) string { +func (sb sqliteBuilder) AlterTableDropColumn(pTableName string, pColumnName string) string { return sb.generic.AlterTableDropColumn(pTableName, pColumnName) } @@ -84,7 +86,7 @@ func (sb sqliteBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) string } func (sb sqliteBuilder) columnDefinition(pColumn schema.Column) string { - lColumnDef := fmt.Sprintf("%s %s", pColumn.Name().String(), sb.ColumnTypeAsString(pColumn.Type())) + lColumnDef := fmt.Sprintf("%s %s", sb.quoteIdentifier(pColumn.Name()), sb.ColumnTypeAsString(pColumn.Type())) if pColumn.NotNull() { lColumnDef += " NOT NULL" } @@ -130,6 +132,13 @@ func (sb sqliteBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsPerP return sb.generic.SelectWithPagination(pSQLSelectCommand, pRowsPerPage, pPageNumber) } +func (sb sqliteBuilder) quoteIdentifier(pIdentifier string) string { + // SQLite uses double quotes (same as PostgreSQL) + // Escape rule: double quote becomes two double quotes + pIdentifier = strings.ReplaceAll(pIdentifier, `"`, `""`) + return fmt.Sprintf(`"%s"`, pIdentifier) +} + func (sb sqliteBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { return sb.generic.InteroperateSQLCommandWithNamedArgs(pSQLCommand, pNamedArgs...) } diff --git a/dialects/sqlserver/builder.go b/dialects/sqlserver/builder.go index 5227a83..2cc165f 100644 --- a/dialects/sqlserver/builder.go +++ b/dialects/sqlserver/builder.go @@ -3,6 +3,7 @@ package sqlserver import ( "database/sql" "fmt" + "strings" "github.com/ordershift/ormshift" "github.com/ordershift/ormshift/internal" @@ -14,9 +15,9 @@ type sqlserverBuilder struct { } func newSQLServerBuilder() ormshift.SQLBuilder { - lBuilder := sqlserverBuilder{} - lBuilder.generic = internal.NewGenericSQLBuilder(lBuilder.columnDefinition, nil) - return lBuilder + sb := sqlserverBuilder{} + sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, sb.quoteIdentifier, nil) + return sb } func (sb sqlserverBuilder) CreateTable(pTable schema.Table) string { @@ -32,7 +33,7 @@ func (sb sqlserverBuilder) CreateTable(pTable schema.Table) string { if lPKColumns != "" { lPKColumns += "," } - lPKColumns += lColumn.Name().String() + lPKColumns += sb.quoteIdentifier(lColumn.Name()) } } @@ -40,20 +41,21 @@ func (sb sqlserverBuilder) CreateTable(pTable schema.Table) string { if lColumns != "" { lColumns += "," } - lColumns += fmt.Sprintf("CONSTRAINT PK_%s PRIMARY KEY (%s)", pTable.Name().String(), lPKColumns) + lPKConstraintName := sb.quoteIdentifier("PK_" + pTable.Name()) + lColumns += fmt.Sprintf("CONSTRAINT %s PRIMARY KEY (%s)", lPKConstraintName, lPKColumns) } - return fmt.Sprintf("CREATE TABLE %s (%s);", pTable.Name().String(), lColumns) + return fmt.Sprintf("CREATE TABLE %s (%s);", sb.quoteIdentifier(pTable.Name()), lColumns) } -func (sb sqlserverBuilder) DropTable(pTableName schema.TableName) string { +func (sb sqlserverBuilder) DropTable(pTableName string) string { return sb.generic.DropTable(pTableName) } -func (sb sqlserverBuilder) AlterTableAddColumn(pTableName schema.TableName, pColumn schema.Column) string { +func (sb sqlserverBuilder) AlterTableAddColumn(pTableName string, pColumn schema.Column) string { return sb.generic.AlterTableAddColumn(pTableName, pColumn) } -func (sb sqlserverBuilder) AlterTableDropColumn(pTableName schema.TableName, pColumnName schema.ColumnName) string { +func (sb sqlserverBuilder) AlterTableDropColumn(pTableName string, pColumnName string) string { return sb.generic.AlterTableDropColumn(pTableName, pColumnName) } @@ -79,7 +81,7 @@ func (sb sqlserverBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) str } func (sb sqlserverBuilder) columnDefinition(pColumn schema.Column) string { - lColumnDef := pColumn.Name().String() + lColumnDef := sb.quoteIdentifier(pColumn.Name()) if pColumn.Type() == schema.Varchar { lColumnDef += fmt.Sprintf(" %s(%d)", sb.ColumnTypeAsString(pColumn.Type()), pColumn.Size()) } else { @@ -138,6 +140,14 @@ func (sb sqlserverBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsP return lSelectWithPagination } +func (sb sqlserverBuilder) quoteIdentifier(pIdentifier string) string { + // SQL Server uses square brackets: [identifier] + // Escape rule: ] becomes ]] + // Example: users -> [users], table]name -> [table]]name] + pIdentifier = strings.ReplaceAll(pIdentifier, "]", "]]") + return fmt.Sprintf("[%s]", pIdentifier) +} + func (sb sqlserverBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { return sb.generic.InteroperateSQLCommandWithNamedArgs(pSQLCommand, pNamedArgs...) } From d8f109233bdb53806a3debc3abeb18873bae316f Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 13:29:17 -0300 Subject: [PATCH 06/26] fixed dialect builder tests to build --- dialects/postgresql/builder_test.go | 10 +++++----- dialects/sqlite/builder_test.go | 10 +++++----- dialects/sqlserver/builder_test.go | 10 +++++----- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/dialects/postgresql/builder_test.go b/dialects/postgresql/builder_test.go index 698acff..0caf73b 100644 --- a/dialects/postgresql/builder_test.go +++ b/dialects/postgresql/builder_test.go @@ -45,12 +45,12 @@ func TestCreateTable(t *testing.T) { lExpectedSQL := "CREATE TABLE user (id BIGSERIAL NOT NULL,email VARCHAR(80) NOT NULL,name VARCHAR(50) NOT NULL," + "password_hash VARCHAR(256),active SMALLINT,created_at TIMESTAMP(6),user_master BIGINT,master_user_id BIGINT," + "licence_price NUMERIC(17,2),relevance DOUBLE PRECISION,photo BYTEA,any VARCHAR,PRIMARY KEY (id,email));" - lReturnedSQL := lSQLBuilder.CreateTable(*lUserTable) + lReturnedSQL := lSQLBuilder.CreateTable(lUserTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") lProductAttributeTable := testutils.FakeProductAttributeTable(t) lExpectedSQL = "CREATE TABLE product_attribute (product_id BIGINT NOT NULL,attribute_id BIGINT NOT NULL,value VARCHAR(75),position BIGINT,PRIMARY KEY (product_id,attribute_id));" - lReturnedSQL = lSQLBuilder.CreateTable(*lProductAttributeTable) + lReturnedSQL = lSQLBuilder.CreateTable(lProductAttributeTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") } @@ -59,7 +59,7 @@ func TestDropTable(t *testing.T) { lUserTableName := testutils.FakeUserTableName(t) lExpectedSQL := "DROP TABLE user;" - lReturnedSQL := lSQLBuilder.DropTable(*lUserTableName) + lReturnedSQL := lSQLBuilder.DropTable(lUserTableName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.DropTable") } @@ -69,7 +69,7 @@ func TestAlterTableAddColumn(t *testing.T) { lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumn := testutils.FakeUpdatedAtColumn(t) lExpectedSQL := "ALTER TABLE user ADD COLUMN updated_at TIMESTAMP(6);" - lReturnedSQL := lSQLBuilder.AlterTableAddColumn(*lUserTableName, *lUpdatedAtColumn) + lReturnedSQL := lSQLBuilder.AlterTableAddColumn(lUserTableName, lUpdatedAtColumn) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableAddColumn") } @@ -79,7 +79,7 @@ func TestAlterTableDropColumn(t *testing.T) { lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumnName := testutils.FakeUpdatedAtColumnName(t) lExpectedSQL := "ALTER TABLE user DROP COLUMN updated_at;" - lReturnedSQL := lSQLBuilder.AlterTableDropColumn(*lUserTableName, *lUpdatedAtColumnName) + lReturnedSQL := lSQLBuilder.AlterTableDropColumn(lUserTableName, lUpdatedAtColumnName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableDropColumn") } diff --git a/dialects/sqlite/builder_test.go b/dialects/sqlite/builder_test.go index 95e8ab9..4876336 100644 --- a/dialects/sqlite/builder_test.go +++ b/dialects/sqlite/builder_test.go @@ -24,12 +24,12 @@ func TestCreateTable(t *testing.T) { lUserTable := testutils.FakeUserTable(t) lExpectedSQL := "CREATE TABLE user (id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,email TEXT NOT NULL,name TEXT NOT NULL," + "password_hash TEXT,active INTEGER,created_at DATETIME,user_master INTEGER,master_user_id INTEGER,licence_price REAL,relevance REAL,photo BLOB,any TEXT);" - lReturnedSQL := lSQLBuilder.CreateTable(*lUserTable) + lReturnedSQL := lSQLBuilder.CreateTable(lUserTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") lProductAttributeTable := testutils.FakeProductAttributeTable(t) lExpectedSQL = "CREATE TABLE product_attribute (product_id INTEGER NOT NULL,attribute_id INTEGER NOT NULL,value TEXT,position INTEGER,CONSTRAINT PK_product_attribute PRIMARY KEY (product_id,attribute_id));" - lReturnedSQL = lSQLBuilder.CreateTable(*lProductAttributeTable) + lReturnedSQL = lSQLBuilder.CreateTable(lProductAttributeTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") } @@ -38,7 +38,7 @@ func TestDropTable(t *testing.T) { lUserTableName := testutils.FakeUserTableName(t) lExpectedSQL := "DROP TABLE user;" - lReturnedSQL := lSQLBuilder.DropTable(*lUserTableName) + lReturnedSQL := lSQLBuilder.DropTable(lUserTableName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.DropTable") } @@ -48,7 +48,7 @@ func TestAlterTableAddColumn(t *testing.T) { lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumn := testutils.FakeUpdatedAtColumn(t) lExpectedSQL := "ALTER TABLE user ADD COLUMN updated_at DATETIME;" - lReturnedSQL := lSQLBuilder.AlterTableAddColumn(*lUserTableName, *lUpdatedAtColumn) + lReturnedSQL := lSQLBuilder.AlterTableAddColumn(lUserTableName, lUpdatedAtColumn) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableAddColumn") } @@ -58,7 +58,7 @@ func TestAlterTableDropColumn(t *testing.T) { lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumnName := testutils.FakeUpdatedAtColumnName(t) lExpectedSQL := "ALTER TABLE user DROP COLUMN updated_at;" - lReturnedSQL := lSQLBuilder.AlterTableDropColumn(*lUserTableName, *lUpdatedAtColumnName) + lReturnedSQL := lSQLBuilder.AlterTableDropColumn(lUserTableName, lUpdatedAtColumnName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableDropColumn") } diff --git a/dialects/sqlserver/builder_test.go b/dialects/sqlserver/builder_test.go index 30a811c..baadb47 100644 --- a/dialects/sqlserver/builder_test.go +++ b/dialects/sqlserver/builder_test.go @@ -25,12 +25,12 @@ func TestCreateTable(t *testing.T) { lExpectedSQL := "CREATE TABLE user (id BIGINT NOT NULL IDENTITY (1, 1),email VARCHAR(80) NOT NULL,name VARCHAR(50) NOT NULL," + "password_hash VARCHAR(256),active BIT,created_at DATETIME2(6),user_master BIGINT,master_user_id BIGINT," + "licence_price MONEY,relevance FLOAT,photo VARBINARY(MAX),any VARCHAR,CONSTRAINT PK_user PRIMARY KEY (id,email));" - lReturnedSQL := lSQLBuilder.CreateTable(*lUserTable) + lReturnedSQL := lSQLBuilder.CreateTable(lUserTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") lProductAttributeTable := testutils.FakeProductAttributeTable(t) lExpectedSQL = "CREATE TABLE product_attribute (product_id BIGINT NOT NULL,attribute_id BIGINT NOT NULL,value VARCHAR(75),position BIGINT,CONSTRAINT PK_product_attribute PRIMARY KEY (product_id,attribute_id));" - lReturnedSQL = lSQLBuilder.CreateTable(*lProductAttributeTable) + lReturnedSQL = lSQLBuilder.CreateTable(lProductAttributeTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") } @@ -39,7 +39,7 @@ func TestDropTable(t *testing.T) { lUserTableName := testutils.FakeUserTableName(t) lExpectedSQL := "DROP TABLE user;" - lReturnedSQL := lSQLBuilder.DropTable(*lUserTableName) + lReturnedSQL := lSQLBuilder.DropTable(lUserTableName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.DropTable") } @@ -49,7 +49,7 @@ func TestAlterTableAddColumn(t *testing.T) { lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumn := testutils.FakeUpdatedAtColumn(t) lExpectedSQL := "ALTER TABLE user ADD COLUMN updated_at DATETIME2(6);" - lReturnedSQL := lSQLBuilder.AlterTableAddColumn(*lUserTableName, *lUpdatedAtColumn) + lReturnedSQL := lSQLBuilder.AlterTableAddColumn(lUserTableName, lUpdatedAtColumn) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableAddColumn") } @@ -59,7 +59,7 @@ func TestAlterTableDropColumn(t *testing.T) { lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumnName := testutils.FakeUpdatedAtColumnName(t) lExpectedSQL := "ALTER TABLE user DROP COLUMN updated_at;" - lReturnedSQL := lSQLBuilder.AlterTableDropColumn(*lUserTableName, *lUpdatedAtColumnName) + lReturnedSQL := lSQLBuilder.AlterTableDropColumn(lUserTableName, lUpdatedAtColumnName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableDropColumn") } From 815f2f3e0763253a468c579d1b5c9f91305aba11 Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 14:09:24 -0300 Subject: [PATCH 07/26] applied quoting mechanism on schema queries --- dialects/postgresql/builder.go | 20 ++++++++++---------- dialects/postgresql/driver.go | 2 +- dialects/postgresql/schema.go | 6 ++++++ dialects/sqlite/builder.go | 18 +++++++++--------- dialects/sqlite/driver.go | 2 +- dialects/sqlite/schema.go | 6 ++++++ dialects/sqlserver/builder.go | 20 ++++++++++---------- dialects/sqlserver/driver.go | 2 +- dialects/sqlserver/schema.go | 6 ++++++ schema/schema.go | 22 ++++++++++++++++------ schema/schema_test.go | 9 +++++++-- 11 files changed, 73 insertions(+), 40 deletions(-) diff --git a/dialects/postgresql/builder.go b/dialects/postgresql/builder.go index 7e22e10..0a756fc 100644 --- a/dialects/postgresql/builder.go +++ b/dialects/postgresql/builder.go @@ -17,7 +17,7 @@ type postgresqlBuilder struct { func newPostgreSQLBuilder() ormshift.SQLBuilder { sb := postgresqlBuilder{} - sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, sb.quoteIdentifier, sb.InteroperateSQLCommandWithNamedArgs) + sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, QuoteIdentifier, sb.InteroperateSQLCommandWithNamedArgs) return sb } @@ -59,7 +59,7 @@ func (sb postgresqlBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) st } func (sb postgresqlBuilder) columnDefinition(pColumn schema.Column) string { - lColumnDef := sb.quoteIdentifier(pColumn.Name()) + lColumnDef := QuoteIdentifier(pColumn.Name()) if pColumn.AutoIncrement() { lColumnDef += " BIGSERIAL" } else { @@ -111,14 +111,6 @@ func (sb postgresqlBuilder) SelectWithPagination(pSQLSelectCommand string, pRows return sb.generic.SelectWithPagination(pSQLSelectCommand, pRowsPerPage, pPageNumber) } -func (sb postgresqlBuilder) quoteIdentifier(pIdentifier string) string { - // PostgreSQL uses double quotes: "identifier" - // Escape rule: double quote becomes two double quotes - // Example: users -> "users", table"name -> "table""name" - pIdentifier = strings.ReplaceAll(pIdentifier, `"`, `""`) - return fmt.Sprintf(`"%s"`, pIdentifier) -} - func (sb postgresqlBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { lSQLCommand := pSQLCommand lArgs := []any{} @@ -146,3 +138,11 @@ func (sb postgresqlBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand stri }) return lSQLCommand, lArgs } + +func QuoteIdentifier(pIdentifier string) string { + // PostgreSQL uses double quotes: "identifier" + // Escape rule: double quote becomes two double quotes + // Example: users -> "users", table"name -> "table""name" + pIdentifier = strings.ReplaceAll(pIdentifier, `"`, `""`) + return fmt.Sprintf(`"%s"`, pIdentifier) +} diff --git a/dialects/postgresql/driver.go b/dialects/postgresql/driver.go index 4962e14..15e385e 100644 --- a/dialects/postgresql/driver.go +++ b/dialects/postgresql/driver.go @@ -38,5 +38,5 @@ func (d postgresqlDriver) SQLBuilder() ormshift.SQLBuilder { } func (d postgresqlDriver) DBSchema(pDB *sql.DB) (*schema.DBSchema, error) { - return schema.NewDBSchema(pDB, tableNamesQuery) + return schema.NewDBSchema(pDB, tableNamesQuery, columnTypesQueryFunc) } diff --git a/dialects/postgresql/schema.go b/dialects/postgresql/schema.go index 00e5d79..a6d69a2 100644 --- a/dialects/postgresql/schema.go +++ b/dialects/postgresql/schema.go @@ -1,5 +1,7 @@ package postgresql +import "fmt" + const tableNamesQuery = ` SELECT table_name @@ -11,3 +13,7 @@ const tableNamesQuery = ` ORDER BY table_name ` + +func columnTypesQueryFunc(pTableName string) string { + return fmt.Sprintf("SELECT * FROM %s WHERE 1=0", QuoteIdentifier(pTableName)) +} diff --git a/dialects/sqlite/builder.go b/dialects/sqlite/builder.go index 1633bad..875447b 100644 --- a/dialects/sqlite/builder.go +++ b/dialects/sqlite/builder.go @@ -16,7 +16,7 @@ type sqliteBuilder struct { func newSQLiteBuilder() ormshift.SQLBuilder { sb := sqliteBuilder{} - sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, sb.quoteIdentifier, nil) + sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, QuoteIdentifier, nil) return sb } @@ -46,10 +46,10 @@ func (sb sqliteBuilder) CreateTable(pTable schema.Table) string { if lColumns != "" { lColumns += "," } - lPKConstraintName := sb.quoteIdentifier("PK_" + pTable.Name()) + lPKConstraintName := QuoteIdentifier("PK_" + pTable.Name()) lColumns += fmt.Sprintf("CONSTRAINT %s PRIMARY KEY (%s)", lPKConstraintName, lPKColumns) } - return fmt.Sprintf("CREATE TABLE %s (%s);", sb.quoteIdentifier(pTable.Name()), lColumns) + return fmt.Sprintf("CREATE TABLE %s (%s);", QuoteIdentifier(pTable.Name()), lColumns) } func (sb sqliteBuilder) DropTable(pTableName string) string { @@ -86,7 +86,7 @@ func (sb sqliteBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) string } func (sb sqliteBuilder) columnDefinition(pColumn schema.Column) string { - lColumnDef := fmt.Sprintf("%s %s", sb.quoteIdentifier(pColumn.Name()), sb.ColumnTypeAsString(pColumn.Type())) + lColumnDef := fmt.Sprintf("%s %s", QuoteIdentifier(pColumn.Name()), sb.ColumnTypeAsString(pColumn.Type())) if pColumn.NotNull() { lColumnDef += " NOT NULL" } @@ -132,13 +132,13 @@ func (sb sqliteBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsPerP return sb.generic.SelectWithPagination(pSQLSelectCommand, pRowsPerPage, pPageNumber) } -func (sb sqliteBuilder) quoteIdentifier(pIdentifier string) string { +func (sb sqliteBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { + return sb.generic.InteroperateSQLCommandWithNamedArgs(pSQLCommand, pNamedArgs...) +} + +func QuoteIdentifier(pIdentifier string) string { // SQLite uses double quotes (same as PostgreSQL) // Escape rule: double quote becomes two double quotes pIdentifier = strings.ReplaceAll(pIdentifier, `"`, `""`) return fmt.Sprintf(`"%s"`, pIdentifier) } - -func (sb sqliteBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { - return sb.generic.InteroperateSQLCommandWithNamedArgs(pSQLCommand, pNamedArgs...) -} diff --git a/dialects/sqlite/driver.go b/dialects/sqlite/driver.go index 9c5c808..39f9d28 100644 --- a/dialects/sqlite/driver.go +++ b/dialects/sqlite/driver.go @@ -40,5 +40,5 @@ func (d sqliteDriver) SQLBuilder() ormshift.SQLBuilder { } func (d sqliteDriver) DBSchema(pDB *sql.DB) (*schema.DBSchema, error) { - return schema.NewDBSchema(pDB, tableNamesQuery) + return schema.NewDBSchema(pDB, tableNamesQuery, columnTypesQueryFunc) } diff --git a/dialects/sqlite/schema.go b/dialects/sqlite/schema.go index fa284eb..056faa5 100644 --- a/dialects/sqlite/schema.go +++ b/dialects/sqlite/schema.go @@ -1,5 +1,7 @@ package sqlite +import "fmt" + const tableNamesQuery = ` SELECT name @@ -10,3 +12,7 @@ const tableNamesQuery = ` ORDER BY name ` + +func columnTypesQueryFunc(pTableName string) string { + return fmt.Sprintf("SELECT * FROM %s WHERE 1=0", QuoteIdentifier(pTableName)) +} diff --git a/dialects/sqlserver/builder.go b/dialects/sqlserver/builder.go index 2cc165f..31f5b72 100644 --- a/dialects/sqlserver/builder.go +++ b/dialects/sqlserver/builder.go @@ -16,7 +16,7 @@ type sqlserverBuilder struct { func newSQLServerBuilder() ormshift.SQLBuilder { sb := sqlserverBuilder{} - sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, sb.quoteIdentifier, nil) + sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, QuoteIdentifier, nil) return sb } @@ -33,7 +33,7 @@ func (sb sqlserverBuilder) CreateTable(pTable schema.Table) string { if lPKColumns != "" { lPKColumns += "," } - lPKColumns += sb.quoteIdentifier(lColumn.Name()) + lPKColumns += QuoteIdentifier(lColumn.Name()) } } @@ -41,10 +41,10 @@ func (sb sqlserverBuilder) CreateTable(pTable schema.Table) string { if lColumns != "" { lColumns += "," } - lPKConstraintName := sb.quoteIdentifier("PK_" + pTable.Name()) + lPKConstraintName := QuoteIdentifier("PK_" + pTable.Name()) lColumns += fmt.Sprintf("CONSTRAINT %s PRIMARY KEY (%s)", lPKConstraintName, lPKColumns) } - return fmt.Sprintf("CREATE TABLE %s (%s);", sb.quoteIdentifier(pTable.Name()), lColumns) + return fmt.Sprintf("CREATE TABLE %s (%s);", QuoteIdentifier(pTable.Name()), lColumns) } func (sb sqlserverBuilder) DropTable(pTableName string) string { @@ -81,7 +81,7 @@ func (sb sqlserverBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) str } func (sb sqlserverBuilder) columnDefinition(pColumn schema.Column) string { - lColumnDef := sb.quoteIdentifier(pColumn.Name()) + lColumnDef := QuoteIdentifier(pColumn.Name()) if pColumn.Type() == schema.Varchar { lColumnDef += fmt.Sprintf(" %s(%d)", sb.ColumnTypeAsString(pColumn.Type()), pColumn.Size()) } else { @@ -140,14 +140,14 @@ func (sb sqlserverBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsP return lSelectWithPagination } -func (sb sqlserverBuilder) quoteIdentifier(pIdentifier string) string { +func (sb sqlserverBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { + return sb.generic.InteroperateSQLCommandWithNamedArgs(pSQLCommand, pNamedArgs...) +} + +func QuoteIdentifier(pIdentifier string) string { // SQL Server uses square brackets: [identifier] // Escape rule: ] becomes ]] // Example: users -> [users], table]name -> [table]]name] pIdentifier = strings.ReplaceAll(pIdentifier, "]", "]]") return fmt.Sprintf("[%s]", pIdentifier) } - -func (sb sqlserverBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { - return sb.generic.InteroperateSQLCommandWithNamedArgs(pSQLCommand, pNamedArgs...) -} diff --git a/dialects/sqlserver/driver.go b/dialects/sqlserver/driver.go index d63ff96..39c316b 100644 --- a/dialects/sqlserver/driver.go +++ b/dialects/sqlserver/driver.go @@ -37,5 +37,5 @@ func (d sqlserverDriver) SQLBuilder() ormshift.SQLBuilder { } func (d sqlserverDriver) DBSchema(pDB *sql.DB) (*schema.DBSchema, error) { - return schema.NewDBSchema(pDB, tableNamesQuery) + return schema.NewDBSchema(pDB, tableNamesQuery, columnTypesQueryFunc) } diff --git a/dialects/sqlserver/schema.go b/dialects/sqlserver/schema.go index d4019cf..d22e786 100644 --- a/dialects/sqlserver/schema.go +++ b/dialects/sqlserver/schema.go @@ -1,5 +1,7 @@ package sqlserver +import "fmt" + const tableNamesQuery = ` SELECT t.name @@ -15,3 +17,7 @@ const tableNamesQuery = ` ORDER BY t.name ` + +func columnTypesQueryFunc(pTableName string) string { + return fmt.Sprintf("SELECT * FROM %s WHERE 1=0", QuoteIdentifier(pTableName)) +} diff --git a/schema/schema.go b/schema/schema.go index dd95ed5..0449812 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -3,21 +3,31 @@ package schema import ( "database/sql" "errors" - "fmt" "slices" "strings" ) type DBSchema struct { - db *sql.DB - tableNamesQuery string + db *sql.DB + tableNamesQuery string + columnTypesQueryFunc ColumnTypesQueryFunc } -func NewDBSchema(pDB *sql.DB, pTableNamesQuery string) (*DBSchema, error) { +type ColumnTypesQueryFunc func(pTableName string) string + +func NewDBSchema( + pDB *sql.DB, + pTableNamesQuery string, + pColumnTypesQueryFunc ColumnTypesQueryFunc, +) (*DBSchema, error) { if pDB == nil { return nil, errors.New("sql.DB cannot be nil") } - return &DBSchema{db: pDB, tableNamesQuery: pTableNamesQuery}, nil + return &DBSchema{ + db: pDB, + tableNamesQuery: pTableNamesQuery, + columnTypesQueryFunc: pColumnTypesQueryFunc, + }, nil } func (s DBSchema) HasTable(pTableName string) bool { @@ -62,7 +72,7 @@ func (s DBSchema) HasColumn(pTableName string, pColumnName string) bool { } func (s DBSchema) fetchColumnTypes(pTableName string) (rColumnTypes []*sql.ColumnType, rError error) { - lRows, rError := s.db.Query(fmt.Sprintf("SELECT * FROM %s WHERE 1=0", pTableName)) // NOSONAR go:S2077 - Dynamic SQL is controlled and sanitized internally + lRows, rError := s.db.Query(s.columnTypesQueryFunc(pTableName)) if rError != nil { return } diff --git a/schema/schema_test.go b/schema/schema_test.go index 22a0a31..7985eb9 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -1,6 +1,7 @@ package schema_test import ( + "fmt" "testing" "github.com/ordershift/ormshift" @@ -9,6 +10,10 @@ import ( "github.com/ordershift/ormshift/schema" ) +func testColumnTypesQueryFunc(pTableName string) string { + return fmt.Sprintf("SELECT * FROM %s WHERE 1=0", pTableName) +} + func TestNewDBSchema(t *testing.T) { lDB, lError := ormshift.OpenDatabase(sqlite.Driver(), ormshift.ConnectionParams{InMemory: true}) if !testutils.AssertNotNilResultAndNilError(t, lDB, lError, "ormshift.OpenDatabase") { @@ -16,14 +21,14 @@ func TestNewDBSchema(t *testing.T) { } defer func() { _ = lDB.Close() }() - lDBSchema, lError := schema.NewDBSchema(lDB.DB(), "query") + lDBSchema, lError := schema.NewDBSchema(lDB.DB(), "query", testColumnTypesQueryFunc) if !testutils.AssertNotNilResultAndNilError(t, lDBSchema, lError, "schema.NewDBSchema") { return } } func TestNewDBSchemaFailsWhenDBIsNil(t *testing.T) { - lDBSchema, lError := schema.NewDBSchema(nil, "query") + lDBSchema, lError := schema.NewDBSchema(nil, "query", testColumnTypesQueryFunc) if !testutils.AssertNilResultAndNotNilError(t, lDBSchema, lError, "schema.NewDBSchema") { return } From 50c2732f85740b79ec6665dd0f5fb37819e5412e Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 14:11:37 -0300 Subject: [PATCH 08/26] fixed linting problems --- internal/testutils/migrations.go | 3 +++ migrations/migrator.go | 7 +++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/internal/testutils/migrations.go b/internal/testutils/migrations.go index c8b1b0d..564283c 100644 --- a/internal/testutils/migrations.go +++ b/internal/testutils/migrations.go @@ -63,6 +63,9 @@ func (m M001_Create_Table_User) Up(pMigrator *migrations.Migrator) error { NotNull: false, }, ) + if lError != nil { + return lError + } _, lError = pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().CreateTable(lUserTable)) return lError diff --git a/migrations/migrator.go b/migrations/migrator.go index 4304bf3..d819e36 100644 --- a/migrations/migrator.go +++ b/migrations/migrator.go @@ -167,7 +167,7 @@ func ensureMigrationsTableExists(pDatabase ormshift.Database, pConfig MigratorCo if pDatabase.DBSchema().HasTable(lMigrationsTable.Name()) { return nil } - lMigrationsTable.AddColumns( + lError := lMigrationsTable.AddColumns( schema.NewColumnParams{ Name: pConfig.migrationNameColumn, Type: schema.Varchar, @@ -181,7 +181,10 @@ func ensureMigrationsTableExists(pDatabase ormshift.Database, pConfig MigratorCo NotNull: true, }, ) + if lError != nil { + return lError + } - _, lError := pDatabase.SQLExecutor().Exec(pDatabase.SQLBuilder().CreateTable(lMigrationsTable)) // NOSONAR go:S2077 - Dynamic SQL is controlled and sanitized internally + _, lError = pDatabase.SQLExecutor().Exec(pDatabase.SQLBuilder().CreateTable(lMigrationsTable)) // NOSONAR go:S2077 - Dynamic SQL is controlled and sanitized internally return lError } From 7765abc37bd95c41fd4430b5c4e0ceb24c026111 Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 14:34:50 -0300 Subject: [PATCH 09/26] changed sqlbuilder to be a single instance under the driver --- dialects/postgresql/driver.go | 12 ++++++++---- dialects/sqlite/driver.go | 12 ++++++++---- dialects/sqlserver/driver.go | 12 ++++++++---- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/dialects/postgresql/driver.go b/dialects/postgresql/driver.go index 15e385e..51cda93 100644 --- a/dialects/postgresql/driver.go +++ b/dialects/postgresql/driver.go @@ -11,10 +11,14 @@ import ( "github.com/ordershift/ormshift/schema" ) -type postgresqlDriver struct{} +type postgresqlDriver struct { + sqlBuilder ormshift.SQLBuilder +} func Driver() ormshift.DatabaseDriver { - return postgresqlDriver{} + return postgresqlDriver{ + sqlBuilder: newPostgreSQLBuilder(), + } } func (d postgresqlDriver) Name() string { @@ -34,9 +38,9 @@ func (d postgresqlDriver) ConnectionString(pParams ormshift.ConnectionParams) st } func (d postgresqlDriver) SQLBuilder() ormshift.SQLBuilder { - return newPostgreSQLBuilder() + return d.sqlBuilder } func (d postgresqlDriver) DBSchema(pDB *sql.DB) (*schema.DBSchema, error) { - return schema.NewDBSchema(pDB, tableNamesQuery, columnTypesQueryFunc) + return schema.NewDBSchema(pDB, tableNamesQuery, columnTypesQueryFunc(d.sqlBuilder)) } diff --git a/dialects/sqlite/driver.go b/dialects/sqlite/driver.go index 39f9d28..d8f28a3 100644 --- a/dialects/sqlite/driver.go +++ b/dialects/sqlite/driver.go @@ -11,10 +11,14 @@ import ( "github.com/ordershift/ormshift/schema" ) -type sqliteDriver struct{} +type sqliteDriver struct { + sqlBuilder ormshift.SQLBuilder +} func Driver() ormshift.DatabaseDriver { - return sqliteDriver{} + return sqliteDriver{ + sqlBuilder: newSQLiteBuilder(), + } } func (d sqliteDriver) Name() string { @@ -36,9 +40,9 @@ func (d sqliteDriver) ConnectionString(pParams ormshift.ConnectionParams) string } func (d sqliteDriver) SQLBuilder() ormshift.SQLBuilder { - return newSQLiteBuilder() + return d.sqlBuilder } func (d sqliteDriver) DBSchema(pDB *sql.DB) (*schema.DBSchema, error) { - return schema.NewDBSchema(pDB, tableNamesQuery, columnTypesQueryFunc) + return schema.NewDBSchema(pDB, tableNamesQuery, columnTypesQueryFunc(d.sqlBuilder)) } diff --git a/dialects/sqlserver/driver.go b/dialects/sqlserver/driver.go index 39c316b..f66f709 100644 --- a/dialects/sqlserver/driver.go +++ b/dialects/sqlserver/driver.go @@ -11,10 +11,14 @@ import ( "github.com/ordershift/ormshift/schema" ) -type sqlserverDriver struct{} +type sqlserverDriver struct { + sqlBuilder ormshift.SQLBuilder +} func Driver() ormshift.DatabaseDriver { - return sqlserverDriver{} + return sqlserverDriver{ + sqlBuilder: newSQLServerBuilder(), + } } func (d sqlserverDriver) Name() string { @@ -33,9 +37,9 @@ func (d sqlserverDriver) ConnectionString(pParams ormshift.ConnectionParams) str } func (d sqlserverDriver) SQLBuilder() ormshift.SQLBuilder { - return newSQLServerBuilder() + return d.sqlBuilder } func (d sqlserverDriver) DBSchema(pDB *sql.DB) (*schema.DBSchema, error) { - return schema.NewDBSchema(pDB, tableNamesQuery, columnTypesQueryFunc) + return schema.NewDBSchema(pDB, tableNamesQuery, columnTypesQueryFunc(d.sqlBuilder)) } From d84a0b2a50b975b895987e6fcb0ad9c837f2df31 Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 14:35:32 -0300 Subject: [PATCH 10/26] set QuoteIdentifier function as part of SQLBuilder interface --- builder.go | 1 + dialects/postgresql/builder.go | 20 ++++++++++---------- dialects/sqlite/builder.go | 18 +++++++++--------- dialects/sqlserver/builder.go | 20 ++++++++++---------- internal/builder_generic.go | 4 ++++ 5 files changed, 34 insertions(+), 29 deletions(-) diff --git a/builder.go b/builder.go index cbf4731..5a6b984 100644 --- a/builder.go +++ b/builder.go @@ -91,4 +91,5 @@ type DMLSQLBuilder interface { type SQLBuilder interface { DDLSQLBuilder DMLSQLBuilder + QuoteIdentifier(pIdentifier string) string } diff --git a/dialects/postgresql/builder.go b/dialects/postgresql/builder.go index 0a756fc..773a48c 100644 --- a/dialects/postgresql/builder.go +++ b/dialects/postgresql/builder.go @@ -17,7 +17,7 @@ type postgresqlBuilder struct { func newPostgreSQLBuilder() ormshift.SQLBuilder { sb := postgresqlBuilder{} - sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, QuoteIdentifier, sb.InteroperateSQLCommandWithNamedArgs) + sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, sb.QuoteIdentifier, sb.InteroperateSQLCommandWithNamedArgs) return sb } @@ -59,7 +59,7 @@ func (sb postgresqlBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) st } func (sb postgresqlBuilder) columnDefinition(pColumn schema.Column) string { - lColumnDef := QuoteIdentifier(pColumn.Name()) + lColumnDef := sb.QuoteIdentifier(pColumn.Name()) if pColumn.AutoIncrement() { lColumnDef += " BIGSERIAL" } else { @@ -111,6 +111,14 @@ func (sb postgresqlBuilder) SelectWithPagination(pSQLSelectCommand string, pRows return sb.generic.SelectWithPagination(pSQLSelectCommand, pRowsPerPage, pPageNumber) } +func (sb postgresqlBuilder) QuoteIdentifier(pIdentifier string) string { + // PostgreSQL uses double quotes: "identifier" + // Escape rule: double quote becomes two double quotes + // Example: users -> "users", table"name -> "table""name" + pIdentifier = strings.ReplaceAll(pIdentifier, `"`, `""`) + return fmt.Sprintf(`"%s"`, pIdentifier) +} + func (sb postgresqlBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { lSQLCommand := pSQLCommand lArgs := []any{} @@ -138,11 +146,3 @@ func (sb postgresqlBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand stri }) return lSQLCommand, lArgs } - -func QuoteIdentifier(pIdentifier string) string { - // PostgreSQL uses double quotes: "identifier" - // Escape rule: double quote becomes two double quotes - // Example: users -> "users", table"name -> "table""name" - pIdentifier = strings.ReplaceAll(pIdentifier, `"`, `""`) - return fmt.Sprintf(`"%s"`, pIdentifier) -} diff --git a/dialects/sqlite/builder.go b/dialects/sqlite/builder.go index 875447b..6413939 100644 --- a/dialects/sqlite/builder.go +++ b/dialects/sqlite/builder.go @@ -16,7 +16,7 @@ type sqliteBuilder struct { func newSQLiteBuilder() ormshift.SQLBuilder { sb := sqliteBuilder{} - sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, QuoteIdentifier, nil) + sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, sb.QuoteIdentifier, nil) return sb } @@ -46,10 +46,10 @@ func (sb sqliteBuilder) CreateTable(pTable schema.Table) string { if lColumns != "" { lColumns += "," } - lPKConstraintName := QuoteIdentifier("PK_" + pTable.Name()) + lPKConstraintName := sb.QuoteIdentifier("PK_" + pTable.Name()) lColumns += fmt.Sprintf("CONSTRAINT %s PRIMARY KEY (%s)", lPKConstraintName, lPKColumns) } - return fmt.Sprintf("CREATE TABLE %s (%s);", QuoteIdentifier(pTable.Name()), lColumns) + return fmt.Sprintf("CREATE TABLE %s (%s);", sb.QuoteIdentifier(pTable.Name()), lColumns) } func (sb sqliteBuilder) DropTable(pTableName string) string { @@ -86,7 +86,7 @@ func (sb sqliteBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) string } func (sb sqliteBuilder) columnDefinition(pColumn schema.Column) string { - lColumnDef := fmt.Sprintf("%s %s", QuoteIdentifier(pColumn.Name()), sb.ColumnTypeAsString(pColumn.Type())) + lColumnDef := fmt.Sprintf("%s %s", sb.QuoteIdentifier(pColumn.Name()), sb.ColumnTypeAsString(pColumn.Type())) if pColumn.NotNull() { lColumnDef += " NOT NULL" } @@ -132,13 +132,13 @@ func (sb sqliteBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsPerP return sb.generic.SelectWithPagination(pSQLSelectCommand, pRowsPerPage, pPageNumber) } -func (sb sqliteBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { - return sb.generic.InteroperateSQLCommandWithNamedArgs(pSQLCommand, pNamedArgs...) -} - -func QuoteIdentifier(pIdentifier string) string { +func (sb sqliteBuilder) QuoteIdentifier(pIdentifier string) string { // SQLite uses double quotes (same as PostgreSQL) // Escape rule: double quote becomes two double quotes pIdentifier = strings.ReplaceAll(pIdentifier, `"`, `""`) return fmt.Sprintf(`"%s"`, pIdentifier) } + +func (sb sqliteBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { + return sb.generic.InteroperateSQLCommandWithNamedArgs(pSQLCommand, pNamedArgs...) +} diff --git a/dialects/sqlserver/builder.go b/dialects/sqlserver/builder.go index 31f5b72..6016739 100644 --- a/dialects/sqlserver/builder.go +++ b/dialects/sqlserver/builder.go @@ -16,7 +16,7 @@ type sqlserverBuilder struct { func newSQLServerBuilder() ormshift.SQLBuilder { sb := sqlserverBuilder{} - sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, QuoteIdentifier, nil) + sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, sb.QuoteIdentifier, nil) return sb } @@ -33,7 +33,7 @@ func (sb sqlserverBuilder) CreateTable(pTable schema.Table) string { if lPKColumns != "" { lPKColumns += "," } - lPKColumns += QuoteIdentifier(lColumn.Name()) + lPKColumns += sb.QuoteIdentifier(lColumn.Name()) } } @@ -41,10 +41,10 @@ func (sb sqlserverBuilder) CreateTable(pTable schema.Table) string { if lColumns != "" { lColumns += "," } - lPKConstraintName := QuoteIdentifier("PK_" + pTable.Name()) + lPKConstraintName := sb.QuoteIdentifier("PK_" + pTable.Name()) lColumns += fmt.Sprintf("CONSTRAINT %s PRIMARY KEY (%s)", lPKConstraintName, lPKColumns) } - return fmt.Sprintf("CREATE TABLE %s (%s);", QuoteIdentifier(pTable.Name()), lColumns) + return fmt.Sprintf("CREATE TABLE %s (%s);", sb.QuoteIdentifier(pTable.Name()), lColumns) } func (sb sqlserverBuilder) DropTable(pTableName string) string { @@ -81,7 +81,7 @@ func (sb sqlserverBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) str } func (sb sqlserverBuilder) columnDefinition(pColumn schema.Column) string { - lColumnDef := QuoteIdentifier(pColumn.Name()) + lColumnDef := sb.QuoteIdentifier(pColumn.Name()) if pColumn.Type() == schema.Varchar { lColumnDef += fmt.Sprintf(" %s(%d)", sb.ColumnTypeAsString(pColumn.Type()), pColumn.Size()) } else { @@ -140,14 +140,14 @@ func (sb sqlserverBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsP return lSelectWithPagination } -func (sb sqlserverBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { - return sb.generic.InteroperateSQLCommandWithNamedArgs(pSQLCommand, pNamedArgs...) -} - -func QuoteIdentifier(pIdentifier string) string { +func (sb sqlserverBuilder) QuoteIdentifier(pIdentifier string) string { // SQL Server uses square brackets: [identifier] // Escape rule: ] becomes ]] // Example: users -> [users], table]name -> [table]]name] pIdentifier = strings.ReplaceAll(pIdentifier, "]", "]]") return fmt.Sprintf("[%s]", pIdentifier) } + +func (sb sqlserverBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { + return sb.generic.InteroperateSQLCommandWithNamedArgs(pSQLCommand, pNamedArgs...) +} diff --git a/internal/builder_generic.go b/internal/builder_generic.go index 88b1f5f..b3546c9 100644 --- a/internal/builder_generic.go +++ b/internal/builder_generic.go @@ -176,6 +176,10 @@ func (sb genericSQLBuilder) columnEqualNameList(pColumns []string, pSeparator st return lColumnEqualNameList } +func (sb genericSQLBuilder) QuoteIdentifier(pIdentifier string) string { + return sb.quoteIdentifier(pIdentifier) +} + func (sb genericSQLBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { if sb.InteroperateSQLCommandWithNamedArgsFunc != nil { return sb.InteroperateSQLCommandWithNamedArgsFunc(pSQLCommand, pNamedArgs...) From f9b4d5c85cfce8786babaad7f2954b8612d7e0e8 Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 14:36:22 -0300 Subject: [PATCH 11/26] modified schema functions to rely on sql builder --- dialects/postgresql/schema.go | 12 +++++++++--- dialects/sqlite/schema.go | 12 +++++++++--- dialects/sqlserver/schema.go | 12 +++++++++--- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/dialects/postgresql/schema.go b/dialects/postgresql/schema.go index a6d69a2..722d04c 100644 --- a/dialects/postgresql/schema.go +++ b/dialects/postgresql/schema.go @@ -1,6 +1,10 @@ package postgresql -import "fmt" +import ( + "fmt" + + "github.com/ordershift/ormshift" +) const tableNamesQuery = ` SELECT @@ -14,6 +18,8 @@ const tableNamesQuery = ` table_name ` -func columnTypesQueryFunc(pTableName string) string { - return fmt.Sprintf("SELECT * FROM %s WHERE 1=0", QuoteIdentifier(pTableName)) +func columnTypesQueryFunc(pSQLBuilder ormshift.SQLBuilder) func(string) string { + return func(pTableName string) string { + return fmt.Sprintf("SELECT * FROM %s WHERE 1=0", pSQLBuilder.QuoteIdentifier(pTableName)) + } } diff --git a/dialects/sqlite/schema.go b/dialects/sqlite/schema.go index 056faa5..9519537 100644 --- a/dialects/sqlite/schema.go +++ b/dialects/sqlite/schema.go @@ -1,6 +1,10 @@ package sqlite -import "fmt" +import ( + "fmt" + + "github.com/ordershift/ormshift" +) const tableNamesQuery = ` SELECT @@ -13,6 +17,8 @@ const tableNamesQuery = ` name ` -func columnTypesQueryFunc(pTableName string) string { - return fmt.Sprintf("SELECT * FROM %s WHERE 1=0", QuoteIdentifier(pTableName)) +func columnTypesQueryFunc(pSQLBuilder ormshift.SQLBuilder) func(string) string { + return func(pTableName string) string { + return fmt.Sprintf("SELECT * FROM %s WHERE 1=0", pSQLBuilder.QuoteIdentifier(pTableName)) + } } diff --git a/dialects/sqlserver/schema.go b/dialects/sqlserver/schema.go index d22e786..764086c 100644 --- a/dialects/sqlserver/schema.go +++ b/dialects/sqlserver/schema.go @@ -1,6 +1,10 @@ package sqlserver -import "fmt" +import ( + "fmt" + + "github.com/ordershift/ormshift" +) const tableNamesQuery = ` SELECT @@ -18,6 +22,8 @@ const tableNamesQuery = ` t.name ` -func columnTypesQueryFunc(pTableName string) string { - return fmt.Sprintf("SELECT * FROM %s WHERE 1=0", QuoteIdentifier(pTableName)) +func columnTypesQueryFunc(pSQLBuilder ormshift.SQLBuilder) func(string) string { + return func(pTableName string) string { + return fmt.Sprintf("SELECT * FROM %s WHERE 1=0", pSQLBuilder.QuoteIdentifier(pTableName)) + } } From b221e1a39036b01db1916f672bdc598ab69547f5 Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 14:36:34 -0300 Subject: [PATCH 12/26] applied quoting mechanism on migrator functions --- migrations/migrator.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/migrations/migrator.go b/migrations/migrator.go index d819e36..4a7d5b6 100644 --- a/migrations/migrator.go +++ b/migrations/migrator.go @@ -137,9 +137,9 @@ func getAppliedMigrationNames(pDatabase ormshift.Database, pConfig MigratorConfi q, p := pDatabase.SQLBuilder().InteroperateSQLCommandWithNamedArgs( fmt.Sprintf( "select %s from %s order by %s", - pConfig.migrationNameColumn, - pConfig.tableName, - pConfig.migrationNameColumn, + pDatabase.SQLBuilder().QuoteIdentifier(pConfig.migrationNameColumn), + pDatabase.SQLBuilder().QuoteIdentifier(pConfig.tableName), + pDatabase.SQLBuilder().QuoteIdentifier(pConfig.migrationNameColumn), ), ) lMigrationsRows, rError := pDatabase.SQLExecutor().Query(q, p...) From cbbfc1eb5fed429da9f982f26139af03caa4aefc Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 14:43:37 -0300 Subject: [PATCH 13/26] moved default quoting behavior to generic sql builder --- dialects/postgresql/builder.go | 8 ++----- dialects/sqlite/builder.go | 8 ++----- internal/builder_generic.go | 38 ++++++++++++++++++---------------- 3 files changed, 24 insertions(+), 30 deletions(-) diff --git a/dialects/postgresql/builder.go b/dialects/postgresql/builder.go index 773a48c..785d274 100644 --- a/dialects/postgresql/builder.go +++ b/dialects/postgresql/builder.go @@ -17,7 +17,7 @@ type postgresqlBuilder struct { func newPostgreSQLBuilder() ormshift.SQLBuilder { sb := postgresqlBuilder{} - sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, sb.QuoteIdentifier, sb.InteroperateSQLCommandWithNamedArgs) + sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, nil, sb.InteroperateSQLCommandWithNamedArgs) return sb } @@ -112,11 +112,7 @@ func (sb postgresqlBuilder) SelectWithPagination(pSQLSelectCommand string, pRows } func (sb postgresqlBuilder) QuoteIdentifier(pIdentifier string) string { - // PostgreSQL uses double quotes: "identifier" - // Escape rule: double quote becomes two double quotes - // Example: users -> "users", table"name -> "table""name" - pIdentifier = strings.ReplaceAll(pIdentifier, `"`, `""`) - return fmt.Sprintf(`"%s"`, pIdentifier) + return sb.generic.QuoteIdentifier(pIdentifier) } func (sb postgresqlBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { diff --git a/dialects/sqlite/builder.go b/dialects/sqlite/builder.go index 6413939..8032379 100644 --- a/dialects/sqlite/builder.go +++ b/dialects/sqlite/builder.go @@ -3,7 +3,6 @@ package sqlite import ( "database/sql" "fmt" - "strings" "github.com/ordershift/ormshift" "github.com/ordershift/ormshift/internal" @@ -16,7 +15,7 @@ type sqliteBuilder struct { func newSQLiteBuilder() ormshift.SQLBuilder { sb := sqliteBuilder{} - sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, sb.QuoteIdentifier, nil) + sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, nil, nil) return sb } @@ -133,10 +132,7 @@ func (sb sqliteBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsPerP } func (sb sqliteBuilder) QuoteIdentifier(pIdentifier string) string { - // SQLite uses double quotes (same as PostgreSQL) - // Escape rule: double quote becomes two double quotes - pIdentifier = strings.ReplaceAll(pIdentifier, `"`, `""`) - return fmt.Sprintf(`"%s"`, pIdentifier) + return sb.generic.QuoteIdentifier(pIdentifier) } func (sb sqliteBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { diff --git a/internal/builder_generic.go b/internal/builder_generic.go index b3546c9..36c99a1 100644 --- a/internal/builder_generic.go +++ b/internal/builder_generic.go @@ -33,13 +33,6 @@ func NewGenericSQLBuilder( } } -func (sb genericSQLBuilder) quoteIdentifier(pIdentifier string) string { - if sb.QuoteIdentifierFunc != nil { - return sb.QuoteIdentifierFunc(pIdentifier) - } - return pIdentifier -} - func (sb genericSQLBuilder) CreateTable(pTable schema.Table) string { lColumns := "" lPKColumns := "" @@ -53,7 +46,7 @@ func (sb genericSQLBuilder) CreateTable(pTable schema.Table) string { if lPKColumns != "" { lPKColumns += "," } - lPKColumns += sb.quoteIdentifier(lColumn.Name()) + lPKColumns += sb.QuoteIdentifier(lColumn.Name()) } } @@ -63,19 +56,19 @@ func (sb genericSQLBuilder) CreateTable(pTable schema.Table) string { } lColumns += fmt.Sprintf("PRIMARY KEY (%s)", lPKColumns) } - return fmt.Sprintf("CREATE TABLE %s (%s);", sb.quoteIdentifier(pTable.Name()), lColumns) + return fmt.Sprintf("CREATE TABLE %s (%s);", sb.QuoteIdentifier(pTable.Name()), lColumns) } func (sb genericSQLBuilder) DropTable(pTableName string) string { - return fmt.Sprintf("DROP TABLE %s;", sb.quoteIdentifier(pTableName)) + return fmt.Sprintf("DROP TABLE %s;", sb.QuoteIdentifier(pTableName)) } func (sb genericSQLBuilder) AlterTableAddColumn(pTableName string, pColumn schema.Column) string { - return fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s;", sb.quoteIdentifier(pTableName), sb.columnDefinition(pColumn)) + return fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s;", sb.QuoteIdentifier(pTableName), sb.columnDefinition(pColumn)) } func (sb genericSQLBuilder) AlterTableDropColumn(pTableName string, pColumnName string) string { - return fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s;", sb.quoteIdentifier(pTableName), sb.quoteIdentifier(pColumnName)) + return fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s;", sb.QuoteIdentifier(pTableName), sb.QuoteIdentifier(pColumnName)) } func (sb genericSQLBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) string { @@ -87,11 +80,11 @@ func (sb genericSQLBuilder) columnDefinition(pColumn schema.Column) string { if sb.ColumnDefinitionFunc != nil { return sb.ColumnDefinitionFunc(pColumn) } - return fmt.Sprintf("%s %s", sb.quoteIdentifier(pColumn.Name()), sb.ColumnTypeAsString(pColumn.Type())) + return fmt.Sprintf("%s %s", sb.QuoteIdentifier(pColumn.Name()), sb.ColumnTypeAsString(pColumn.Type())) } func (sb genericSQLBuilder) Insert(pTableName string, pColumns []string) string { - return fmt.Sprintf("insert into %s (%s) values (%s)", sb.quoteIdentifier(pTableName), sb.columnsList(pColumns), sb.namesList(pColumns)) + return fmt.Sprintf("insert into %s (%s) values (%s)", sb.QuoteIdentifier(pTableName), sb.columnsList(pColumns), sb.namesList(pColumns)) } func (sb genericSQLBuilder) InsertWithValues(pTableName string, pColumnsValues ormshift.ColumnsValues) (string, []any) { @@ -101,7 +94,7 @@ func (sb genericSQLBuilder) InsertWithValues(pTableName string, pColumnsValues o } func (sb genericSQLBuilder) Update(pTableName string, pColumns, pColumnsWhere []string) string { - lUpdate := fmt.Sprintf("update %s set %s ", sb.quoteIdentifier(pTableName), sb.columnEqualNameList(pColumns, ",")) + lUpdate := fmt.Sprintf("update %s set %s ", sb.QuoteIdentifier(pTableName), sb.columnEqualNameList(pColumns, ",")) if len(pColumnsWhere) > 0 { lUpdate += fmt.Sprintf("where %s", sb.columnEqualNameList(pColumnsWhere, " and ")) // NOSONAR go:S1192 - duplicate tradeoff accepted } @@ -115,7 +108,7 @@ func (sb genericSQLBuilder) UpdateWithValues(pTableName string, pColumns, pColum } func (sb genericSQLBuilder) Delete(pTableName string, pColumnsWhere []string) string { - lDelete := fmt.Sprintf("delete from %s ", sb.quoteIdentifier(pTableName)) + lDelete := fmt.Sprintf("delete from %s ", sb.QuoteIdentifier(pTableName)) if len(pColumnsWhere) > 0 { lDelete += fmt.Sprintf("where %s", sb.columnEqualNameList(pColumnsWhere, " and ")) // NOSONAR go:S1192 - duplicate tradeoff accepted } @@ -129,7 +122,7 @@ func (sb genericSQLBuilder) DeleteWithValues(pTableName string, pWhereColumnsVal } func (sb genericSQLBuilder) Select(pTableName string, pColumns, pColumnsWhere []string) string { - lUpdate := fmt.Sprintf("select %s from %s ", sb.columnsList(pColumns), sb.quoteIdentifier(pTableName)) + lUpdate := fmt.Sprintf("select %s from %s ", sb.columnsList(pColumns), sb.QuoteIdentifier(pTableName)) if len(pColumnsWhere) > 0 { lUpdate += fmt.Sprintf("where %s", sb.columnEqualNameList(pColumnsWhere, " and ")) // NOSONAR go:S1192 - duplicate tradeoff accepted } @@ -177,13 +170,22 @@ func (sb genericSQLBuilder) columnEqualNameList(pColumns []string, pSeparator st } func (sb genericSQLBuilder) QuoteIdentifier(pIdentifier string) string { - return sb.quoteIdentifier(pIdentifier) + if sb.QuoteIdentifierFunc != nil { + return sb.QuoteIdentifierFunc(pIdentifier) + } + + // Most databases uses double quotes: "identifier" (PostgreSQL, SQLite, etc.) + // Escape rule: double quote becomes two double quotes + // Example: users -> "users", table"name -> "table""name" + pIdentifier = strings.ReplaceAll(pIdentifier, `"`, `""`) + return fmt.Sprintf(`"%s"`, pIdentifier) } func (sb genericSQLBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { if sb.InteroperateSQLCommandWithNamedArgsFunc != nil { return sb.InteroperateSQLCommandWithNamedArgsFunc(pSQLCommand, pNamedArgs...) } + lSQLCommand := pSQLCommand lArgs := []any{} for _, lParam := range pNamedArgs { From 932d83de5512ca61202d8aba99ec1029ba919530 Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 14:49:37 -0300 Subject: [PATCH 14/26] fixed typo in assert --- internal/testutils/fake.go | 2 +- schema/table_test.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/testutils/fake.go b/internal/testutils/fake.go index f7967f5..48207c0 100644 --- a/internal/testutils/fake.go +++ b/internal/testutils/fake.go @@ -129,7 +129,7 @@ func FakeUserTable(t *testing.T) schema.Table { AutoIncrement: false, }, ) - if !AssertNilError(t, lError, "UserTable.AddColumn") { + if !AssertNilError(t, lError, "UserTable.AddColumns") { panic(lError) } return lUserTable diff --git a/schema/table_test.go b/schema/table_test.go index 51e650a..028c5bd 100644 --- a/schema/table_test.go +++ b/schema/table_test.go @@ -14,8 +14,8 @@ func TestAddColumnFailsWhenAlreadyExists(t *testing.T) { Name: "value", Type: schema.Integer, }) - if !testutils.AssertNotNilError(t, lError, "Table.AddColumn") { + if !testutils.AssertNotNilError(t, lError, "Table.AddColumns") { return } - testutils.AssertErrorMessage(t, fmt.Sprintf("column %q already exists in table %q", "value", "product_attribute"), lError, "Table.AddColumn") + testutils.AssertErrorMessage(t, fmt.Sprintf("column %q already exists in table %q", "value", "product_attribute"), lError, "Table.AddColumns") } From e02e356d3d2eeafb60ead71f1e8653ec590b3527 Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 15:26:00 -0300 Subject: [PATCH 15/26] changed structs to have pointer receivers as a standard --- builder.go | 8 +++---- database.go | 11 +++++---- dialects/postgresql/builder.go | 36 ++++++++++++++--------------- dialects/postgresql/driver.go | 10 ++++---- dialects/sqlite/builder.go | 36 ++++++++++++++--------------- dialects/sqlite/driver.go | 10 ++++---- dialects/sqlserver/builder.go | 36 ++++++++++++++--------------- dialects/sqlserver/driver.go | 10 ++++---- internal/builder_generic.go | 42 +++++++++++++++++----------------- migrations/config.go | 20 ++++++++-------- migrations/migrations.go | 2 +- migrations/migrator.go | 34 ++++++++++++++------------- schema/column.go | 12 +++++----- schema/schema.go | 8 +++---- schema/table.go | 6 ++--- 15 files changed, 142 insertions(+), 139 deletions(-) diff --git a/builder.go b/builder.go index 5a6b984..736383a 100644 --- a/builder.go +++ b/builder.go @@ -24,9 +24,9 @@ type ColumnsValues map[string]any // lColumnsValues := ColumnsValues{"id": 5, "sku": "ZTX-9000", "is_simple": true} // lNamedArgs := lColumnsValues.ToNamedArgs() // //lNamedArgs == []sql.NamedArg{{Name: "id", Value: 5},{Name: "is_simple", Value: true},{Name: "sku", Value: "ZTX-9000"}} -func (cv ColumnsValues) ToNamedArgs() []sql.NamedArg { +func (cv *ColumnsValues) ToNamedArgs() []sql.NamedArg { lNamedArgs := []sql.NamedArg{} - for c, v := range cv { + for c, v := range *cv { lNamedArgs = append(lNamedArgs, sql.Named(c, v)) } slices.SortFunc(lNamedArgs, func(a, b sql.NamedArg) int { @@ -39,9 +39,9 @@ func (cv ColumnsValues) ToNamedArgs() []sql.NamedArg { } // ToColumns returns the column names from ColumnsValues as a string array ordered by name, e.g.: -func (cv ColumnsValues) ToColumns() []string { +func (cv *ColumnsValues) ToColumns() []string { lColumns := []string{} - for c := range cv { + for c := range *cv { lColumns = append(lColumns, c) } slices.Sort(lColumns) diff --git a/database.go b/database.go index 059caff..c641af1 100644 --- a/database.go +++ b/database.go @@ -31,7 +31,7 @@ type Database struct { executor SQLExecutor connectionString string sqlBuilder SQLBuilder - schema schema.DBSchema + schema *schema.DBSchema } func OpenDatabase(pDriver DatabaseDriver, pParams ConnectionParams) (*Database, error) { @@ -47,13 +47,16 @@ func OpenDatabase(pDriver DatabaseDriver, pParams ConnectionParams) (*Database, if lError != nil { return nil, fmt.Errorf("failed to get DB schema: %w", lError) } + + // TODO: Unify SQLExecutor interface usage + var lExecutor SQLExecutor = lDB return &Database{ driver: pDriver, db: lDB, - executor: lDB, + executor: lExecutor, connectionString: lConnectionString, sqlBuilder: pDriver.SQLBuilder(), - schema: *lSchema, + schema: lSchema, }, nil } @@ -82,6 +85,6 @@ func (d *Database) SQLBuilder() SQLBuilder { return d.sqlBuilder } -func (d *Database) DBSchema() schema.DBSchema { +func (d *Database) DBSchema() *schema.DBSchema { return d.schema } diff --git a/dialects/postgresql/builder.go b/dialects/postgresql/builder.go index 785d274..9607eac 100644 --- a/dialects/postgresql/builder.go +++ b/dialects/postgresql/builder.go @@ -18,26 +18,26 @@ type postgresqlBuilder struct { func newPostgreSQLBuilder() ormshift.SQLBuilder { sb := postgresqlBuilder{} sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, nil, sb.InteroperateSQLCommandWithNamedArgs) - return sb + return &sb } -func (sb postgresqlBuilder) CreateTable(pTable schema.Table) string { +func (sb *postgresqlBuilder) CreateTable(pTable schema.Table) string { return sb.generic.CreateTable(pTable) } -func (sb postgresqlBuilder) DropTable(pTableName string) string { +func (sb *postgresqlBuilder) DropTable(pTableName string) string { return sb.generic.DropTable(pTableName) } -func (sb postgresqlBuilder) AlterTableAddColumn(pTableName string, pColumn schema.Column) string { +func (sb *postgresqlBuilder) AlterTableAddColumn(pTableName string, pColumn schema.Column) string { return sb.generic.AlterTableAddColumn(pTableName, pColumn) } -func (sb postgresqlBuilder) AlterTableDropColumn(pTableName string, pColumnName string) string { +func (sb *postgresqlBuilder) AlterTableDropColumn(pTableName string, pColumnName string) string { return sb.generic.AlterTableDropColumn(pTableName, pColumnName) } -func (sb postgresqlBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) string { +func (sb *postgresqlBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) string { switch pColumnType { case schema.Varchar: return "VARCHAR" @@ -58,7 +58,7 @@ func (sb postgresqlBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) st } } -func (sb postgresqlBuilder) columnDefinition(pColumn schema.Column) string { +func (sb *postgresqlBuilder) columnDefinition(pColumn schema.Column) string { lColumnDef := sb.QuoteIdentifier(pColumn.Name()) if pColumn.AutoIncrement() { lColumnDef += " BIGSERIAL" @@ -75,47 +75,47 @@ func (sb postgresqlBuilder) columnDefinition(pColumn schema.Column) string { return lColumnDef } -func (sb postgresqlBuilder) Insert(pTableName string, pColumns []string) string { +func (sb *postgresqlBuilder) Insert(pTableName string, pColumns []string) string { return sb.generic.Insert(pTableName, pColumns) } -func (sb postgresqlBuilder) InsertWithValues(pTableName string, pColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *postgresqlBuilder) InsertWithValues(pTableName string, pColumnsValues ormshift.ColumnsValues) (string, []any) { return sb.generic.InsertWithValues(pTableName, pColumnsValues) } -func (sb postgresqlBuilder) Update(pTableName string, pColumns, pColumnsWhere []string) string { +func (sb *postgresqlBuilder) Update(pTableName string, pColumns, pColumnsWhere []string) string { return sb.generic.Update(pTableName, pColumns, pColumnsWhere) } -func (sb postgresqlBuilder) UpdateWithValues(pTableName string, pColumns, pColumnsWhere []string, pValues ormshift.ColumnsValues) (string, []any) { +func (sb *postgresqlBuilder) UpdateWithValues(pTableName string, pColumns, pColumnsWhere []string, pValues ormshift.ColumnsValues) (string, []any) { return sb.generic.UpdateWithValues(pTableName, pColumns, pColumnsWhere, pValues) } -func (sb postgresqlBuilder) Delete(pTableName string, pColumnsWhere []string) string { +func (sb *postgresqlBuilder) Delete(pTableName string, pColumnsWhere []string) string { return sb.generic.Delete(pTableName, pColumnsWhere) } -func (sb postgresqlBuilder) DeleteWithValues(pTableName string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *postgresqlBuilder) DeleteWithValues(pTableName string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { return sb.generic.DeleteWithValues(pTableName, pWhereColumnsValues) } -func (sb postgresqlBuilder) Select(pTableName string, pColumns, pColumnsWhere []string) string { +func (sb *postgresqlBuilder) Select(pTableName string, pColumns, pColumnsWhere []string) string { return sb.generic.Select(pTableName, pColumns, pColumnsWhere) } -func (sb postgresqlBuilder) SelectWithValues(pTableName string, pColumns []string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *postgresqlBuilder) SelectWithValues(pTableName string, pColumns []string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { return sb.generic.SelectWithValues(pTableName, pColumns, pWhereColumnsValues) } -func (sb postgresqlBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsPerPage, pPageNumber uint) string { +func (sb *postgresqlBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsPerPage, pPageNumber uint) string { return sb.generic.SelectWithPagination(pSQLSelectCommand, pRowsPerPage, pPageNumber) } -func (sb postgresqlBuilder) QuoteIdentifier(pIdentifier string) string { +func (sb *postgresqlBuilder) QuoteIdentifier(pIdentifier string) string { return sb.generic.QuoteIdentifier(pIdentifier) } -func (sb postgresqlBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { +func (sb *postgresqlBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { lSQLCommand := pSQLCommand lArgs := []any{} lIndexes := map[string]int{} diff --git a/dialects/postgresql/driver.go b/dialects/postgresql/driver.go index 51cda93..1123ac7 100644 --- a/dialects/postgresql/driver.go +++ b/dialects/postgresql/driver.go @@ -16,16 +16,16 @@ type postgresqlDriver struct { } func Driver() ormshift.DatabaseDriver { - return postgresqlDriver{ + return &postgresqlDriver{ sqlBuilder: newPostgreSQLBuilder(), } } -func (d postgresqlDriver) Name() string { +func (d *postgresqlDriver) Name() string { return "postgres" } -func (d postgresqlDriver) ConnectionString(pParams ormshift.ConnectionParams) string { +func (d *postgresqlDriver) ConnectionString(pParams ormshift.ConnectionParams) string { lHost := pParams.Host if lHost == "" { lHost = "localhost" @@ -37,10 +37,10 @@ func (d postgresqlDriver) ConnectionString(pParams ormshift.ConnectionParams) st return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", lHost, lPort, pParams.User, pParams.Password, pParams.Database) } -func (d postgresqlDriver) SQLBuilder() ormshift.SQLBuilder { +func (d *postgresqlDriver) SQLBuilder() ormshift.SQLBuilder { return d.sqlBuilder } -func (d postgresqlDriver) DBSchema(pDB *sql.DB) (*schema.DBSchema, error) { +func (d *postgresqlDriver) DBSchema(pDB *sql.DB) (*schema.DBSchema, error) { return schema.NewDBSchema(pDB, tableNamesQuery, columnTypesQueryFunc(d.sqlBuilder)) } diff --git a/dialects/sqlite/builder.go b/dialects/sqlite/builder.go index 8032379..6dff7d4 100644 --- a/dialects/sqlite/builder.go +++ b/dialects/sqlite/builder.go @@ -16,10 +16,10 @@ type sqliteBuilder struct { func newSQLiteBuilder() ormshift.SQLBuilder { sb := sqliteBuilder{} sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, nil, nil) - return sb + return &sb } -func (sb sqliteBuilder) CreateTable(pTable schema.Table) string { +func (sb *sqliteBuilder) CreateTable(pTable schema.Table) string { lColumns := "" lPKColumns := "" lHasAutoIncrementColumn := false @@ -51,19 +51,19 @@ func (sb sqliteBuilder) CreateTable(pTable schema.Table) string { return fmt.Sprintf("CREATE TABLE %s (%s);", sb.QuoteIdentifier(pTable.Name()), lColumns) } -func (sb sqliteBuilder) DropTable(pTableName string) string { +func (sb *sqliteBuilder) DropTable(pTableName string) string { return sb.generic.DropTable(pTableName) } -func (sb sqliteBuilder) AlterTableAddColumn(pTableName string, pColumn schema.Column) string { +func (sb *sqliteBuilder) AlterTableAddColumn(pTableName string, pColumn schema.Column) string { return sb.generic.AlterTableAddColumn(pTableName, pColumn) } -func (sb sqliteBuilder) AlterTableDropColumn(pTableName string, pColumnName string) string { +func (sb *sqliteBuilder) AlterTableDropColumn(pTableName string, pColumnName string) string { return sb.generic.AlterTableDropColumn(pTableName, pColumnName) } -func (sb sqliteBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) string { +func (sb *sqliteBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) string { switch pColumnType { case schema.Varchar: return "TEXT" @@ -84,7 +84,7 @@ func (sb sqliteBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) string } } -func (sb sqliteBuilder) columnDefinition(pColumn schema.Column) string { +func (sb *sqliteBuilder) columnDefinition(pColumn schema.Column) string { lColumnDef := fmt.Sprintf("%s %s", sb.QuoteIdentifier(pColumn.Name()), sb.ColumnTypeAsString(pColumn.Type())) if pColumn.NotNull() { lColumnDef += " NOT NULL" @@ -95,46 +95,46 @@ func (sb sqliteBuilder) columnDefinition(pColumn schema.Column) string { return lColumnDef } -func (sb sqliteBuilder) Insert(pTableName string, pColumns []string) string { +func (sb *sqliteBuilder) Insert(pTableName string, pColumns []string) string { return sb.generic.Insert(pTableName, pColumns) } -func (sb sqliteBuilder) InsertWithValues(pTableName string, pColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *sqliteBuilder) InsertWithValues(pTableName string, pColumnsValues ormshift.ColumnsValues) (string, []any) { return sb.generic.InsertWithValues(pTableName, pColumnsValues) } -func (sb sqliteBuilder) Update(pTableName string, pColumns, pColumnsWhere []string) string { +func (sb *sqliteBuilder) Update(pTableName string, pColumns, pColumnsWhere []string) string { return sb.generic.Update(pTableName, pColumns, pColumnsWhere) } -func (sb sqliteBuilder) UpdateWithValues(pTableName string, pColumns, pColumnsWhere []string, pValues ormshift.ColumnsValues) (string, []any) { +func (sb *sqliteBuilder) UpdateWithValues(pTableName string, pColumns, pColumnsWhere []string, pValues ormshift.ColumnsValues) (string, []any) { return sb.generic.UpdateWithValues(pTableName, pColumns, pColumnsWhere, pValues) } -func (sb sqliteBuilder) Delete(pTableName string, pColumnsWhere []string) string { +func (sb *sqliteBuilder) Delete(pTableName string, pColumnsWhere []string) string { return sb.generic.Delete(pTableName, pColumnsWhere) } -func (sb sqliteBuilder) DeleteWithValues(pTableName string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *sqliteBuilder) DeleteWithValues(pTableName string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { return sb.generic.DeleteWithValues(pTableName, pWhereColumnsValues) } -func (sb sqliteBuilder) Select(pTableName string, pColumns, pColumnsWhere []string) string { +func (sb *sqliteBuilder) Select(pTableName string, pColumns, pColumnsWhere []string) string { return sb.generic.Select(pTableName, pColumns, pColumnsWhere) } -func (sb sqliteBuilder) SelectWithValues(pTableName string, pColumns []string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *sqliteBuilder) SelectWithValues(pTableName string, pColumns []string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { return sb.generic.SelectWithValues(pTableName, pColumns, pWhereColumnsValues) } -func (sb sqliteBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsPerPage, pPageNumber uint) string { +func (sb *sqliteBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsPerPage, pPageNumber uint) string { return sb.generic.SelectWithPagination(pSQLSelectCommand, pRowsPerPage, pPageNumber) } -func (sb sqliteBuilder) QuoteIdentifier(pIdentifier string) string { +func (sb *sqliteBuilder) QuoteIdentifier(pIdentifier string) string { return sb.generic.QuoteIdentifier(pIdentifier) } -func (sb sqliteBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { +func (sb *sqliteBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { return sb.generic.InteroperateSQLCommandWithNamedArgs(pSQLCommand, pNamedArgs...) } diff --git a/dialects/sqlite/driver.go b/dialects/sqlite/driver.go index d8f28a3..488aaa4 100644 --- a/dialects/sqlite/driver.go +++ b/dialects/sqlite/driver.go @@ -16,16 +16,16 @@ type sqliteDriver struct { } func Driver() ormshift.DatabaseDriver { - return sqliteDriver{ + return &sqliteDriver{ sqlBuilder: newSQLiteBuilder(), } } -func (d sqliteDriver) Name() string { +func (d *sqliteDriver) Name() string { return "sqlite" } -func (d sqliteDriver) ConnectionString(pParams ormshift.ConnectionParams) string { +func (d *sqliteDriver) ConnectionString(pParams ormshift.ConnectionParams) string { if pParams.InMemory { return ":memory:" } @@ -39,10 +39,10 @@ func (d sqliteDriver) ConnectionString(pParams ormshift.ConnectionParams) string return fmt.Sprintf("file:%s.db?%s_locking=NORMAL", pParams.Database, lConnetionWithAuth) } -func (d sqliteDriver) SQLBuilder() ormshift.SQLBuilder { +func (d *sqliteDriver) SQLBuilder() ormshift.SQLBuilder { return d.sqlBuilder } -func (d sqliteDriver) DBSchema(pDB *sql.DB) (*schema.DBSchema, error) { +func (d *sqliteDriver) DBSchema(pDB *sql.DB) (*schema.DBSchema, error) { return schema.NewDBSchema(pDB, tableNamesQuery, columnTypesQueryFunc(d.sqlBuilder)) } diff --git a/dialects/sqlserver/builder.go b/dialects/sqlserver/builder.go index 6016739..6516aa5 100644 --- a/dialects/sqlserver/builder.go +++ b/dialects/sqlserver/builder.go @@ -17,10 +17,10 @@ type sqlserverBuilder struct { func newSQLServerBuilder() ormshift.SQLBuilder { sb := sqlserverBuilder{} sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, sb.QuoteIdentifier, nil) - return sb + return &sb } -func (sb sqlserverBuilder) CreateTable(pTable schema.Table) string { +func (sb *sqlserverBuilder) CreateTable(pTable schema.Table) string { lColumns := "" lPKColumns := "" for _, lColumn := range pTable.Columns() { @@ -47,19 +47,19 @@ func (sb sqlserverBuilder) CreateTable(pTable schema.Table) string { return fmt.Sprintf("CREATE TABLE %s (%s);", sb.QuoteIdentifier(pTable.Name()), lColumns) } -func (sb sqlserverBuilder) DropTable(pTableName string) string { +func (sb *sqlserverBuilder) DropTable(pTableName string) string { return sb.generic.DropTable(pTableName) } -func (sb sqlserverBuilder) AlterTableAddColumn(pTableName string, pColumn schema.Column) string { +func (sb *sqlserverBuilder) AlterTableAddColumn(pTableName string, pColumn schema.Column) string { return sb.generic.AlterTableAddColumn(pTableName, pColumn) } -func (sb sqlserverBuilder) AlterTableDropColumn(pTableName string, pColumnName string) string { +func (sb *sqlserverBuilder) AlterTableDropColumn(pTableName string, pColumnName string) string { return sb.generic.AlterTableDropColumn(pTableName, pColumnName) } -func (sb sqlserverBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) string { +func (sb *sqlserverBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) string { switch pColumnType { case schema.Varchar: return "VARCHAR" @@ -80,7 +80,7 @@ func (sb sqlserverBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) str } } -func (sb sqlserverBuilder) columnDefinition(pColumn schema.Column) string { +func (sb *sqlserverBuilder) columnDefinition(pColumn schema.Column) string { lColumnDef := sb.QuoteIdentifier(pColumn.Name()) if pColumn.Type() == schema.Varchar { lColumnDef += fmt.Sprintf(" %s(%d)", sb.ColumnTypeAsString(pColumn.Type()), pColumn.Size()) @@ -96,39 +96,39 @@ func (sb sqlserverBuilder) columnDefinition(pColumn schema.Column) string { return lColumnDef } -func (sb sqlserverBuilder) Insert(pTableName string, pColumns []string) string { +func (sb *sqlserverBuilder) Insert(pTableName string, pColumns []string) string { return sb.generic.Insert(pTableName, pColumns) } -func (sb sqlserverBuilder) InsertWithValues(pTableName string, pColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *sqlserverBuilder) InsertWithValues(pTableName string, pColumnsValues ormshift.ColumnsValues) (string, []any) { return sb.generic.InsertWithValues(pTableName, pColumnsValues) } -func (sb sqlserverBuilder) Update(pTableName string, pColumns, pColumnsWhere []string) string { +func (sb *sqlserverBuilder) Update(pTableName string, pColumns, pColumnsWhere []string) string { return sb.generic.Update(pTableName, pColumns, pColumnsWhere) } -func (sb sqlserverBuilder) UpdateWithValues(pTableName string, pColumns, pColumnsWhere []string, pValues ormshift.ColumnsValues) (string, []any) { +func (sb *sqlserverBuilder) UpdateWithValues(pTableName string, pColumns, pColumnsWhere []string, pValues ormshift.ColumnsValues) (string, []any) { return sb.generic.UpdateWithValues(pTableName, pColumns, pColumnsWhere, pValues) } -func (sb sqlserverBuilder) Delete(pTableName string, pColumnsWhere []string) string { +func (sb *sqlserverBuilder) Delete(pTableName string, pColumnsWhere []string) string { return sb.generic.Delete(pTableName, pColumnsWhere) } -func (sb sqlserverBuilder) DeleteWithValues(pTableName string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *sqlserverBuilder) DeleteWithValues(pTableName string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { return sb.generic.DeleteWithValues(pTableName, pWhereColumnsValues) } -func (sb sqlserverBuilder) Select(pTableName string, pColumns, pColumnsWhere []string) string { +func (sb *sqlserverBuilder) Select(pTableName string, pColumns, pColumnsWhere []string) string { return sb.generic.Select(pTableName, pColumns, pColumnsWhere) } -func (sb sqlserverBuilder) SelectWithValues(pTableName string, pColumns []string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *sqlserverBuilder) SelectWithValues(pTableName string, pColumns []string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { return sb.generic.SelectWithValues(pTableName, pColumns, pWhereColumnsValues) } -func (sb sqlserverBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsPerPage, pPageNumber uint) string { +func (sb *sqlserverBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsPerPage, pPageNumber uint) string { lSelectWithPagination := pSQLSelectCommand if pRowsPerPage > 0 { lOffSet := uint(0) @@ -140,7 +140,7 @@ func (sb sqlserverBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsP return lSelectWithPagination } -func (sb sqlserverBuilder) QuoteIdentifier(pIdentifier string) string { +func (sb *sqlserverBuilder) QuoteIdentifier(pIdentifier string) string { // SQL Server uses square brackets: [identifier] // Escape rule: ] becomes ]] // Example: users -> [users], table]name -> [table]]name] @@ -148,6 +148,6 @@ func (sb sqlserverBuilder) QuoteIdentifier(pIdentifier string) string { return fmt.Sprintf("[%s]", pIdentifier) } -func (sb sqlserverBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { +func (sb *sqlserverBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { return sb.generic.InteroperateSQLCommandWithNamedArgs(pSQLCommand, pNamedArgs...) } diff --git a/dialects/sqlserver/driver.go b/dialects/sqlserver/driver.go index f66f709..e415ade 100644 --- a/dialects/sqlserver/driver.go +++ b/dialects/sqlserver/driver.go @@ -16,16 +16,16 @@ type sqlserverDriver struct { } func Driver() ormshift.DatabaseDriver { - return sqlserverDriver{ + return &sqlserverDriver{ sqlBuilder: newSQLServerBuilder(), } } -func (d sqlserverDriver) Name() string { +func (d *sqlserverDriver) Name() string { return "sqlserver" } -func (d sqlserverDriver) ConnectionString(pParams ormshift.ConnectionParams) string { +func (d *sqlserverDriver) ConnectionString(pParams ormshift.ConnectionParams) string { lHostInstanceAndPort := pParams.Host if pParams.Instance != "" { lHostInstanceAndPort = fmt.Sprintf("%s\\%s", pParams.Host, pParams.Instance) @@ -36,10 +36,10 @@ func (d sqlserverDriver) ConnectionString(pParams ormshift.ConnectionParams) str return fmt.Sprintf("server=%s;user id=%s;password=%s;database=%s", lHostInstanceAndPort, pParams.User, pParams.Password, pParams.Database) } -func (d sqlserverDriver) SQLBuilder() ormshift.SQLBuilder { +func (d *sqlserverDriver) SQLBuilder() ormshift.SQLBuilder { return d.sqlBuilder } -func (d sqlserverDriver) DBSchema(pDB *sql.DB) (*schema.DBSchema, error) { +func (d *sqlserverDriver) DBSchema(pDB *sql.DB) (*schema.DBSchema, error) { return schema.NewDBSchema(pDB, tableNamesQuery, columnTypesQueryFunc(d.sqlBuilder)) } diff --git a/internal/builder_generic.go b/internal/builder_generic.go index 36c99a1..c49d939 100644 --- a/internal/builder_generic.go +++ b/internal/builder_generic.go @@ -26,14 +26,14 @@ func NewGenericSQLBuilder( pQuoteIdentifierFunc QuoteIdentifierFunc, pInteroperateSQLCommandWithNamedArgsFunc InteroperateSQLCommandWithNamedArgsFunc, ) ormshift.SQLBuilder { - return genericSQLBuilder{ + return &genericSQLBuilder{ ColumnDefinitionFunc: pColumnDefinitionFunc, QuoteIdentifierFunc: pQuoteIdentifierFunc, InteroperateSQLCommandWithNamedArgsFunc: pInteroperateSQLCommandWithNamedArgsFunc, } } -func (sb genericSQLBuilder) CreateTable(pTable schema.Table) string { +func (sb *genericSQLBuilder) CreateTable(pTable schema.Table) string { lColumns := "" lPKColumns := "" for _, lColumn := range pTable.Columns() { @@ -59,41 +59,41 @@ func (sb genericSQLBuilder) CreateTable(pTable schema.Table) string { return fmt.Sprintf("CREATE TABLE %s (%s);", sb.QuoteIdentifier(pTable.Name()), lColumns) } -func (sb genericSQLBuilder) DropTable(pTableName string) string { +func (sb *genericSQLBuilder) DropTable(pTableName string) string { return fmt.Sprintf("DROP TABLE %s;", sb.QuoteIdentifier(pTableName)) } -func (sb genericSQLBuilder) AlterTableAddColumn(pTableName string, pColumn schema.Column) string { +func (sb *genericSQLBuilder) AlterTableAddColumn(pTableName string, pColumn schema.Column) string { return fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s;", sb.QuoteIdentifier(pTableName), sb.columnDefinition(pColumn)) } -func (sb genericSQLBuilder) AlterTableDropColumn(pTableName string, pColumnName string) string { +func (sb *genericSQLBuilder) AlterTableDropColumn(pTableName string, pColumnName string) string { return fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s;", sb.QuoteIdentifier(pTableName), sb.QuoteIdentifier(pColumnName)) } -func (sb genericSQLBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) string { +func (sb *genericSQLBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) string { // Generic implementation, should be overridden by specific SQL builders return fmt.Sprintf("<>", pColumnType) } -func (sb genericSQLBuilder) columnDefinition(pColumn schema.Column) string { +func (sb *genericSQLBuilder) columnDefinition(pColumn schema.Column) string { if sb.ColumnDefinitionFunc != nil { return sb.ColumnDefinitionFunc(pColumn) } return fmt.Sprintf("%s %s", sb.QuoteIdentifier(pColumn.Name()), sb.ColumnTypeAsString(pColumn.Type())) } -func (sb genericSQLBuilder) Insert(pTableName string, pColumns []string) string { +func (sb *genericSQLBuilder) Insert(pTableName string, pColumns []string) string { return fmt.Sprintf("insert into %s (%s) values (%s)", sb.QuoteIdentifier(pTableName), sb.columnsList(pColumns), sb.namesList(pColumns)) } -func (sb genericSQLBuilder) InsertWithValues(pTableName string, pColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *genericSQLBuilder) InsertWithValues(pTableName string, pColumnsValues ormshift.ColumnsValues) (string, []any) { lInsertSQL := sb.Insert(pTableName, pColumnsValues.ToColumns()) lInsertArgs := pColumnsValues.ToNamedArgs() return sb.InteroperateSQLCommandWithNamedArgs(lInsertSQL, lInsertArgs...) } -func (sb genericSQLBuilder) Update(pTableName string, pColumns, pColumnsWhere []string) string { +func (sb *genericSQLBuilder) Update(pTableName string, pColumns, pColumnsWhere []string) string { lUpdate := fmt.Sprintf("update %s set %s ", sb.QuoteIdentifier(pTableName), sb.columnEqualNameList(pColumns, ",")) if len(pColumnsWhere) > 0 { lUpdate += fmt.Sprintf("where %s", sb.columnEqualNameList(pColumnsWhere, " and ")) // NOSONAR go:S1192 - duplicate tradeoff accepted @@ -101,13 +101,13 @@ func (sb genericSQLBuilder) Update(pTableName string, pColumns, pColumnsWhere [] return lUpdate } -func (sb genericSQLBuilder) UpdateWithValues(pTableName string, pColumns, pColumnsWhere []string, pValues ormshift.ColumnsValues) (string, []any) { +func (sb *genericSQLBuilder) UpdateWithValues(pTableName string, pColumns, pColumnsWhere []string, pValues ormshift.ColumnsValues) (string, []any) { lUpdateSQL := sb.Update(pTableName, pColumns, pColumnsWhere) lUpdateArgs := pValues.ToNamedArgs() return sb.InteroperateSQLCommandWithNamedArgs(lUpdateSQL, lUpdateArgs...) } -func (sb genericSQLBuilder) Delete(pTableName string, pColumnsWhere []string) string { +func (sb *genericSQLBuilder) Delete(pTableName string, pColumnsWhere []string) string { lDelete := fmt.Sprintf("delete from %s ", sb.QuoteIdentifier(pTableName)) if len(pColumnsWhere) > 0 { lDelete += fmt.Sprintf("where %s", sb.columnEqualNameList(pColumnsWhere, " and ")) // NOSONAR go:S1192 - duplicate tradeoff accepted @@ -115,13 +115,13 @@ func (sb genericSQLBuilder) Delete(pTableName string, pColumnsWhere []string) st return lDelete } -func (sb genericSQLBuilder) DeleteWithValues(pTableName string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *genericSQLBuilder) DeleteWithValues(pTableName string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { lDeleteSQL := sb.Delete(pTableName, pWhereColumnsValues.ToColumns()) lDeleteArgs := pWhereColumnsValues.ToNamedArgs() return sb.InteroperateSQLCommandWithNamedArgs(lDeleteSQL, lDeleteArgs...) } -func (sb genericSQLBuilder) Select(pTableName string, pColumns, pColumnsWhere []string) string { +func (sb *genericSQLBuilder) Select(pTableName string, pColumns, pColumnsWhere []string) string { lUpdate := fmt.Sprintf("select %s from %s ", sb.columnsList(pColumns), sb.QuoteIdentifier(pTableName)) if len(pColumnsWhere) > 0 { lUpdate += fmt.Sprintf("where %s", sb.columnEqualNameList(pColumnsWhere, " and ")) // NOSONAR go:S1192 - duplicate tradeoff accepted @@ -129,13 +129,13 @@ func (sb genericSQLBuilder) Select(pTableName string, pColumns, pColumnsWhere [] return lUpdate } -func (sb genericSQLBuilder) SelectWithValues(pTableName string, pColumns []string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *genericSQLBuilder) SelectWithValues(pTableName string, pColumns []string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { lSelectSQL := sb.Select(pTableName, pColumns, pWhereColumnsValues.ToColumns()) lSelectArgs := pWhereColumnsValues.ToNamedArgs() return sb.InteroperateSQLCommandWithNamedArgs(lSelectSQL, lSelectArgs...) } -func (sb genericSQLBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsPerPage, pPageNumber uint) string { +func (sb *genericSQLBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsPerPage, pPageNumber uint) string { lSelectWithPagination := pSQLSelectCommand if pRowsPerPage > 0 { lSelectWithPagination += fmt.Sprintf(" LIMIT %d", pRowsPerPage) @@ -146,11 +146,11 @@ func (sb genericSQLBuilder) SelectWithPagination(pSQLSelectCommand string, pRows return lSelectWithPagination } -func (sb genericSQLBuilder) columnsList(pColumns []string) string { +func (sb *genericSQLBuilder) columnsList(pColumns []string) string { return strings.Join(pColumns, ",") } -func (sb genericSQLBuilder) namesList(pColumns []string) string { +func (sb *genericSQLBuilder) namesList(pColumns []string) string { lNames := []string{} for _, lColumn := range pColumns { lNames = append(lNames, "@"+lColumn) @@ -158,7 +158,7 @@ func (sb genericSQLBuilder) namesList(pColumns []string) string { return strings.Join(lNames, ",") } -func (sb genericSQLBuilder) columnEqualNameList(pColumns []string, pSeparator string) string { +func (sb *genericSQLBuilder) columnEqualNameList(pColumns []string, pSeparator string) string { lColumnEqualNameList := "" for _, lColumn := range pColumns { if lColumnEqualNameList != "" { @@ -169,7 +169,7 @@ func (sb genericSQLBuilder) columnEqualNameList(pColumns []string, pSeparator st return lColumnEqualNameList } -func (sb genericSQLBuilder) QuoteIdentifier(pIdentifier string) string { +func (sb *genericSQLBuilder) QuoteIdentifier(pIdentifier string) string { if sb.QuoteIdentifierFunc != nil { return sb.QuoteIdentifierFunc(pIdentifier) } @@ -181,7 +181,7 @@ func (sb genericSQLBuilder) QuoteIdentifier(pIdentifier string) string { return fmt.Sprintf(`"%s"`, pIdentifier) } -func (sb genericSQLBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { +func (sb *genericSQLBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { if sb.InteroperateSQLCommandWithNamedArgsFunc != nil { return sb.InteroperateSQLCommandWithNamedArgsFunc(pSQLCommand, pNamedArgs...) } diff --git a/migrations/config.go b/migrations/config.go index 917ca9d..4adb1ef 100644 --- a/migrations/config.go +++ b/migrations/config.go @@ -7,41 +7,39 @@ type MigratorConfig struct { appliedAtColumn string } -func NewMigratorConfig() MigratorConfig { +func NewMigratorConfig() *MigratorConfig { lConfig := MigratorConfig{ tableName: "__ormshift_migrations", migrationNameColumn: "name", migrationNameMaxLength: 250, appliedAtColumn: "applied_at", } - return lConfig + return &lConfig } -func (mc MigratorConfig) WithTableName(pTableName string) MigratorConfig { +func (mc *MigratorConfig) WithTableName(pTableName string) *MigratorConfig { mc.tableName = pTableName return mc } - -func (mc MigratorConfig) WithColumnNames(pMigrationNameColumn, pAppliedAtColumn string) MigratorConfig { +func (mc *MigratorConfig) WithColumnNames(pMigrationNameColumn, pAppliedAtColumn string) *MigratorConfig { mc.migrationNameColumn = pMigrationNameColumn mc.appliedAtColumn = pAppliedAtColumn return mc } - -func (mc MigratorConfig) WithMigrationNameMaxLength(pMaxLength uint) MigratorConfig { +func (mc *MigratorConfig) WithMigrationNameMaxLength(pMaxLength uint) *MigratorConfig { mc.migrationNameMaxLength = pMaxLength return mc } -func (mc MigratorConfig) TableName() string { +func (mc *MigratorConfig) TableName() string { return mc.tableName } -func (mc MigratorConfig) MigrationNameColumn() string { +func (mc *MigratorConfig) MigrationNameColumn() string { return mc.migrationNameColumn } -func (mc MigratorConfig) MigrationNameMaxLength() uint { +func (mc *MigratorConfig) MigrationNameMaxLength() uint { return mc.migrationNameMaxLength } -func (mc MigratorConfig) AppliedAtColumn() string { +func (mc *MigratorConfig) AppliedAtColumn() string { return mc.appliedAtColumn } diff --git a/migrations/migrations.go b/migrations/migrations.go index 1e92128..d14fcd0 100644 --- a/migrations/migrations.go +++ b/migrations/migrations.go @@ -7,7 +7,7 @@ type Migration interface { Down(pMigrator *Migrator) error } -func Migrate(pDatabase *ormshift.Database, pConfig MigratorConfig, pMigrations ...Migration) (*Migrator, error) { +func Migrate(pDatabase *ormshift.Database, pConfig *MigratorConfig, pMigrations ...Migration) (*Migrator, error) { lMigrator, lError := NewMigrator(pDatabase, pConfig) if lError != nil { return nil, lError diff --git a/migrations/migrator.go b/migrations/migrator.go index 4a7d5b6..110750c 100644 --- a/migrations/migrator.go +++ b/migrations/migrator.go @@ -11,17 +11,20 @@ import ( type Migrator struct { database *ormshift.Database - config MigratorConfig + config *MigratorConfig migrations []Migration appliedMigrations map[string]bool } -func NewMigrator(pDatabase *ormshift.Database, pConfig MigratorConfig) (*Migrator, error) { +func NewMigrator(pDatabase *ormshift.Database, pConfig *MigratorConfig) (*Migrator, error) { if pDatabase == nil { return nil, fmt.Errorf("database cannot be nil") } + if pConfig == nil { + return nil, fmt.Errorf("migrator config cannot be nil") + } - lAppliedMigrationNames, lError := getAppliedMigrationNames(*pDatabase, pConfig) + lAppliedMigrationNames, lError := getAppliedMigrationNames(pDatabase, pConfig) if lError != nil { return nil, fmt.Errorf("failed to get applied migration names: %w", lError) } @@ -80,16 +83,15 @@ func (m *Migrator) RevertLastAppliedMigration() error { return nil } -func (m Migrator) Database() *ormshift.Database { +func (m *Migrator) Database() *ormshift.Database { return m.database } -func (m Migrator) Migrations() []Migration { +func (m *Migrator) Migrations() []Migration { return m.migrations } -func (m Migrator) AppliedMigrations() []Migration { - +func (m *Migrator) AppliedMigrations() []Migration { lMigrations := []Migration{} for _, migration := range m.Migrations() { name := reflect.TypeOf(migration).Name() @@ -100,12 +102,12 @@ func (m Migrator) AppliedMigrations() []Migration { return lMigrations } -func (m Migrator) isApplied(pMigrationName string) bool { +func (m *Migrator) isApplied(pMigrationName string) bool { _, exists := m.appliedMigrations[pMigrationName] return exists } -func (m Migrator) recordAppliedMigration(pMigrationName string) error { +func (m *Migrator) recordAppliedMigration(pMigrationName string) error { q, p := m.database.SQLBuilder().InsertWithValues( m.config.tableName, ormshift.ColumnsValues{ @@ -117,7 +119,7 @@ func (m Migrator) recordAppliedMigration(pMigrationName string) error { return lError } -func (m Migrator) deleteAppliedMigration(pMigrationName string) error { +func (m *Migrator) deleteAppliedMigration(pMigrationName string) error { q, p := m.database.SQLBuilder().DeleteWithValues( m.config.tableName, ormshift.ColumnsValues{ @@ -128,7 +130,7 @@ func (m Migrator) deleteAppliedMigration(pMigrationName string) error { return lError } -func getAppliedMigrationNames(pDatabase ormshift.Database, pConfig MigratorConfig) (rMigrationNames []string, rError error) { +func getAppliedMigrationNames(pDatabase *ormshift.Database, pConfig *MigratorConfig) (rMigrationNames []string, rError error) { rError = ensureMigrationsTableExists(pDatabase, pConfig) if rError != nil { return @@ -162,21 +164,21 @@ func getAppliedMigrationNames(pDatabase ormshift.Database, pConfig MigratorConfi return } -func ensureMigrationsTableExists(pDatabase ormshift.Database, pConfig MigratorConfig) error { - lMigrationsTable := schema.NewTable(pConfig.tableName) +func ensureMigrationsTableExists(pDatabase *ormshift.Database, pConfig *MigratorConfig) error { + lMigrationsTable := schema.NewTable(pConfig.TableName()) if pDatabase.DBSchema().HasTable(lMigrationsTable.Name()) { return nil } lError := lMigrationsTable.AddColumns( schema.NewColumnParams{ - Name: pConfig.migrationNameColumn, + Name: pConfig.MigrationNameColumn(), Type: schema.Varchar, - Size: pConfig.migrationNameMaxLength, + Size: pConfig.MigrationNameMaxLength(), PrimaryKey: true, NotNull: true, }, schema.NewColumnParams{ - Name: pConfig.appliedAtColumn, + Name: pConfig.AppliedAtColumn(), Type: schema.DateTime, NotNull: true, }, diff --git a/schema/column.go b/schema/column.go index 40f09c3..7eed88e 100644 --- a/schema/column.go +++ b/schema/column.go @@ -41,26 +41,26 @@ func NewColumn(pParams NewColumnParams) Column { } } -func (c Column) Name() string { +func (c *Column) Name() string { return c.name } -func (c Column) Type() ColumnType { +func (c *Column) Type() ColumnType { return c.columnType } -func (c Column) Size() uint { +func (c *Column) Size() uint { return c.size } -func (c Column) PrimaryKey() bool { +func (c *Column) PrimaryKey() bool { return c.primaryKey } -func (c Column) NotNull() bool { +func (c *Column) NotNull() bool { return c.notNull } -func (c Column) AutoIncrement() bool { +func (c *Column) AutoIncrement() bool { return c.autoIncrement } diff --git a/schema/schema.go b/schema/schema.go index 0449812..64922ce 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -30,7 +30,7 @@ func NewDBSchema( }, nil } -func (s DBSchema) HasTable(pTableName string) bool { +func (s *DBSchema) HasTable(pTableName string) bool { lTables, lError := s.fetchTableNames() if lError != nil { return false @@ -40,7 +40,7 @@ func (s DBSchema) HasTable(pTableName string) bool { }) } -func (s DBSchema) fetchTableNames() (rTableNames []string, rError error) { +func (s *DBSchema) fetchTableNames() (rTableNames []string, rError error) { lRows, rError := s.db.Query(s.tableNamesQuery) if rError != nil { return @@ -61,7 +61,7 @@ func (s DBSchema) fetchTableNames() (rTableNames []string, rError error) { return } -func (s DBSchema) HasColumn(pTableName string, pColumnName string) bool { +func (s *DBSchema) HasColumn(pTableName string, pColumnName string) bool { lColumnTypes, lError := s.fetchColumnTypes(pTableName) if lError != nil { return false @@ -71,7 +71,7 @@ func (s DBSchema) HasColumn(pTableName string, pColumnName string) bool { }) } -func (s DBSchema) fetchColumnTypes(pTableName string) (rColumnTypes []*sql.ColumnType, rError error) { +func (s *DBSchema) fetchColumnTypes(pTableName string) (rColumnTypes []*sql.ColumnType, rError error) { lRows, rError := s.db.Query(s.columnTypesQueryFunc(pTableName)) if rError != nil { return diff --git a/schema/table.go b/schema/table.go index b273c9b..c1a5ce2 100644 --- a/schema/table.go +++ b/schema/table.go @@ -18,15 +18,15 @@ func NewTable(pName string) Table { } } -func (t Table) Name() string { +func (t *Table) Name() string { return t.name } -func (t Table) Columns() []Column { +func (t *Table) Columns() []Column { return t.columns } -func (t Table) AddColumns(pParams ...NewColumnParams) error { +func (t *Table) AddColumns(pParams ...NewColumnParams) error { for _, lColParams := range pParams { lColumn := NewColumn(lColParams) lColumnAlreadyExists := slices.ContainsFunc(t.columns, func(c Column) bool { From e4402b0c536dd9e7bc09557b4f8e1794ae3bc71d Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 15:39:35 -0300 Subject: [PATCH 16/26] included quoting in expected produced SQL commands (AI agent) --- dialects/postgresql/builder_test.go | 30 +++++++++++++-------------- dialects/sqlite/builder_test.go | 28 ++++++++++++------------- dialects/sqlserver/builder_test.go | 30 +++++++++++++-------------- internal/builder_generic_test.go | 32 ++++++++++++++--------------- 4 files changed, 60 insertions(+), 60 deletions(-) diff --git a/dialects/postgresql/builder_test.go b/dialects/postgresql/builder_test.go index 0caf73b..62a838f 100644 --- a/dialects/postgresql/builder_test.go +++ b/dialects/postgresql/builder_test.go @@ -42,14 +42,14 @@ func TestCreateTable(t *testing.T) { lSQLBuilder := postgresql.Driver().SQLBuilder() lUserTable := testutils.FakeUserTable(t) - lExpectedSQL := "CREATE TABLE user (id BIGSERIAL NOT NULL,email VARCHAR(80) NOT NULL,name VARCHAR(50) NOT NULL," + - "password_hash VARCHAR(256),active SMALLINT,created_at TIMESTAMP(6),user_master BIGINT,master_user_id BIGINT," + - "licence_price NUMERIC(17,2),relevance DOUBLE PRECISION,photo BYTEA,any VARCHAR,PRIMARY KEY (id,email));" + lExpectedSQL := "CREATE TABLE \"user\" (\"id\" BIGSERIAL NOT NULL,\"email\" VARCHAR(80) NOT NULL,\"name\" VARCHAR(50) NOT NULL," + + "\"password_hash\" VARCHAR(256),\"active\" SMALLINT,\"created_at\" TIMESTAMP(6),\"user_master\" BIGINT,\"master_user_id\" BIGINT," + + "\"licence_price\" NUMERIC(17,2),\"relevance\" DOUBLE PRECISION,\"photo\" BYTEA,\"any\" VARCHAR,PRIMARY KEY (\"id\",\"email\"));" lReturnedSQL := lSQLBuilder.CreateTable(lUserTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") lProductAttributeTable := testutils.FakeProductAttributeTable(t) - lExpectedSQL = "CREATE TABLE product_attribute (product_id BIGINT NOT NULL,attribute_id BIGINT NOT NULL,value VARCHAR(75),position BIGINT,PRIMARY KEY (product_id,attribute_id));" + lExpectedSQL = "CREATE TABLE \"product_attribute\" (\"product_id\" BIGINT NOT NULL,\"attribute_id\" BIGINT NOT NULL,\"value\" VARCHAR(75),\"position\" BIGINT,PRIMARY KEY (\"product_id\",\"attribute_id\"));" lReturnedSQL = lSQLBuilder.CreateTable(lProductAttributeTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") } @@ -58,7 +58,7 @@ func TestDropTable(t *testing.T) { lSQLBuilder := postgresql.Driver().SQLBuilder() lUserTableName := testutils.FakeUserTableName(t) - lExpectedSQL := "DROP TABLE user;" + lExpectedSQL := "DROP TABLE \"user\";" lReturnedSQL := lSQLBuilder.DropTable(lUserTableName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.DropTable") } @@ -68,7 +68,7 @@ func TestAlterTableAddColumn(t *testing.T) { lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumn := testutils.FakeUpdatedAtColumn(t) - lExpectedSQL := "ALTER TABLE user ADD COLUMN updated_at TIMESTAMP(6);" + lExpectedSQL := "ALTER TABLE \"user\" ADD COLUMN \"updated_at\" TIMESTAMP(6);" lReturnedSQL := lSQLBuilder.AlterTableAddColumn(lUserTableName, lUpdatedAtColumn) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableAddColumn") } @@ -78,7 +78,7 @@ func TestAlterTableDropColumn(t *testing.T) { lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumnName := testutils.FakeUpdatedAtColumnName(t) - lExpectedSQL := "ALTER TABLE user DROP COLUMN updated_at;" + lExpectedSQL := "ALTER TABLE \"user\" DROP COLUMN \"updated_at\";" lReturnedSQL := lSQLBuilder.AlterTableDropColumn(lUserTableName, lUpdatedAtColumnName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableDropColumn") } @@ -87,7 +87,7 @@ func TestInsert(t *testing.T) { lSQLBuilder := postgresql.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Insert("product", []string{"id", "sku", "name", "description"}) - lExpectedSQL := "insert into product (id,sku,name,description) values (@id,@sku,@name,@description)" + lExpectedSQL := "insert into \"product\" (\"id\",\"sku\",\"name\",\"description\") values (@id,@sku,@name,@description)" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Insert") } @@ -95,7 +95,7 @@ func TestInsertWithValues(t *testing.T) { lSQLBuilder := postgresql.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.InsertWithValues("product", ormshift.ColumnsValues{"id": 1, "sku": "1.005.12.9", "name": "Trufa Sabor Amarula 30g Cacaushow"}) - lExpectedSQL := "insert into product (id,name,sku) values ($1,$2,$3)" + lExpectedSQL := "insert into \"product\" (\"id\",\"name\",\"sku\") values ($1,$2,$3)" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.InsertWithValues.SQL") testutils.AssertEqualWithLabel(t, 3, len(lReturnedValues), "SQLBuilder.InsertWithValues.Values") testutils.AssertEqualWithLabel(t, 1, lReturnedValues[0], "SQLBuilder.InsertWithValues.Values[0]") @@ -107,7 +107,7 @@ func TestUpdate(t *testing.T) { lSQLBuilder := postgresql.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Update("product", []string{"sku", "name", "description"}, []string{"id"}) - lExpectedSQL := "update product set sku = @sku,name = @name,description = @description where id = @id" + lExpectedSQL := "update \"product\" set \"sku\" = @sku,\"name\" = @name,\"description\" = @description where \"id\" = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Update") } @@ -115,7 +115,7 @@ func TestUpdateWithValues(t *testing.T) { lSQLBuilder := postgresql.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.UpdateWithValues("product", []string{"sku", "name"}, []string{"id"}, ormshift.ColumnsValues{"id": 1, "sku": "1.005.12.5", "name": "Trufa Sabor Amarula 18g Cacaushow"}) - lExpectedSQL := "update product set sku = $3,name = $2 where id = $1" + lExpectedSQL := "update \"product\" set \"sku\" = $3,\"name\" = $2 where \"id\" = $1" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.UpdateWithValues.SQL") testutils.AssertEqualWithLabel(t, 3, len(lReturnedValues), "SQLBuilder.UpdateWithValues.Values") testutils.AssertEqualWithLabel(t, 1, lReturnedValues[0], "SQLBuilder.UpdateWithValues.Values[0]") @@ -127,7 +127,7 @@ func TestDelete(t *testing.T) { lSQLBuilder := postgresql.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Delete("product", []string{"id"}) - lExpectedSQL := "delete from product where id = @id" + lExpectedSQL := "delete from \"product\" where \"id\" = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Delete") } @@ -135,7 +135,7 @@ func TestDeleteWithValues(t *testing.T) { lSQLBuilder := postgresql.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.DeleteWithValues("product", ormshift.ColumnsValues{"id": 1}) - lExpectedSQL := "delete from product where id = $1" + lExpectedSQL := "delete from \"product\" where \"id\" = $1" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.DeleteWithValues.SQL") testutils.AssertEqualWithLabel(t, 1, len(lReturnedValues), "SQLBuilder.DeleteWithValues.Values") testutils.AssertEqualWithLabel(t, 1, lReturnedValues[0], "SQLBuilder.DeleteWithValues.Values[0]") @@ -145,7 +145,7 @@ func TestSelect(t *testing.T) { lSQLBuilder := postgresql.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Select("product", []string{"id", "name", "description"}, []string{"sku", "active"}) - lExpectedSQL := "select id,name,description from product where sku = @sku and active = @active" + lExpectedSQL := "select \"id\",\"name\",\"description\" from \"product\" where \"sku\" = @sku and \"active\" = @active" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Select") } @@ -153,7 +153,7 @@ func TestSelectWithValues(t *testing.T) { lSQLBuilder := postgresql.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.SelectWithValues("product", []string{"id", "sku", "name", "description"}, ormshift.ColumnsValues{"category_id": 1, "active": true}) - lExpectedSQL := "select id,sku,name,description from product where active = $1 and category_id = $2" + lExpectedSQL := "select \"id\",\"sku\",\"name\",\"description\" from \"product\" where \"active\" = $1 and \"category_id\" = $2" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.SelectWithValues.SQL") testutils.AssertEqualWithLabel(t, 2, len(lReturnedValues), "SQLBuilder.SelectWithValues.Values") testutils.AssertEqualWithLabel(t, 1, lReturnedValues[0], "SQLBuilder.SelectWithValues.Values[0]") diff --git a/dialects/sqlite/builder_test.go b/dialects/sqlite/builder_test.go index 4876336..fb386c5 100644 --- a/dialects/sqlite/builder_test.go +++ b/dialects/sqlite/builder_test.go @@ -22,13 +22,13 @@ func TestCreateTable(t *testing.T) { lSQLBuilder := sqlite.Driver().SQLBuilder() lUserTable := testutils.FakeUserTable(t) - lExpectedSQL := "CREATE TABLE user (id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,email TEXT NOT NULL,name TEXT NOT NULL," + - "password_hash TEXT,active INTEGER,created_at DATETIME,user_master INTEGER,master_user_id INTEGER,licence_price REAL,relevance REAL,photo BLOB,any TEXT);" + lExpectedSQL := "CREATE TABLE \"user\" (\"id\" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,\"email\" TEXT NOT NULL,\"name\" TEXT NOT NULL," + + "\"password_hash\" TEXT,\"active\" INTEGER,\"created_at\" DATETIME,\"user_master\" INTEGER,\"master_user_id\" INTEGER,\"licence_price\" REAL,\"relevance\" REAL,\"photo\" BLOB,\"any\" TEXT);" lReturnedSQL := lSQLBuilder.CreateTable(lUserTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") lProductAttributeTable := testutils.FakeProductAttributeTable(t) - lExpectedSQL = "CREATE TABLE product_attribute (product_id INTEGER NOT NULL,attribute_id INTEGER NOT NULL,value TEXT,position INTEGER,CONSTRAINT PK_product_attribute PRIMARY KEY (product_id,attribute_id));" + lExpectedSQL = "CREATE TABLE \"product_attribute\" (\"product_id\" INTEGER NOT NULL,\"attribute_id\" INTEGER NOT NULL,\"value\" TEXT,\"position\" INTEGER,CONSTRAINT \"PK_product_attribute\" PRIMARY KEY (\"product_id\",\"attribute_id\"));" lReturnedSQL = lSQLBuilder.CreateTable(lProductAttributeTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") } @@ -37,7 +37,7 @@ func TestDropTable(t *testing.T) { lSQLBuilder := sqlite.Driver().SQLBuilder() lUserTableName := testutils.FakeUserTableName(t) - lExpectedSQL := "DROP TABLE user;" + lExpectedSQL := "DROP TABLE \"user\";" lReturnedSQL := lSQLBuilder.DropTable(lUserTableName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.DropTable") } @@ -47,7 +47,7 @@ func TestAlterTableAddColumn(t *testing.T) { lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumn := testutils.FakeUpdatedAtColumn(t) - lExpectedSQL := "ALTER TABLE user ADD COLUMN updated_at DATETIME;" + lExpectedSQL := "ALTER TABLE \"user\" ADD COLUMN \"updated_at\" DATETIME;" lReturnedSQL := lSQLBuilder.AlterTableAddColumn(lUserTableName, lUpdatedAtColumn) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableAddColumn") } @@ -57,7 +57,7 @@ func TestAlterTableDropColumn(t *testing.T) { lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumnName := testutils.FakeUpdatedAtColumnName(t) - lExpectedSQL := "ALTER TABLE user DROP COLUMN updated_at;" + lExpectedSQL := "ALTER TABLE \"user\" DROP COLUMN \"updated_at\";" lReturnedSQL := lSQLBuilder.AlterTableDropColumn(lUserTableName, lUpdatedAtColumnName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableDropColumn") } @@ -66,7 +66,7 @@ func TestInsert(t *testing.T) { lSQLBuilder := sqlite.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Insert("product", []string{"id", "sku", "name", "description"}) - lExpectedSQL := "insert into product (id,sku,name,description) values (@id,@sku,@name,@description)" + lExpectedSQL := "insert into \"product\" (\"id\",\"sku\",\"name\",\"description\") values (@id,@sku,@name,@description)" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Insert") } @@ -74,7 +74,7 @@ func TestInsertWithValues(t *testing.T) { lSQLBuilder := sqlite.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.InsertWithValues("product", ormshift.ColumnsValues{"id": 1, "sku": "1.005.12.9", "name": "Trufa Sabor Amarula 30g Cacaushow"}) - lExpectedSQL := "insert into product (id,name,sku) values (@id,@name,@sku)" + lExpectedSQL := "insert into \"product\" (\"id\",\"name\",\"sku\") values (@id,@name,@sku)" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.InsertWithValues.SQL") testutils.AssertEqualWithLabel(t, 3, len(lReturnedValues), "SQLBuilder.InsertWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "id", Value: 1}, "SQLBuilder.InsertWithValues.Values[0]") @@ -86,7 +86,7 @@ func TestUpdate(t *testing.T) { lSQLBuilder := sqlite.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Update("product", []string{"sku", "name", "description"}, []string{"id"}) - lExpectedSQL := "update product set sku = @sku,name = @name,description = @description where id = @id" + lExpectedSQL := "update \"product\" set \"sku\" = @sku,\"name\" = @name,\"description\" = @description where \"id\" = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Update") } @@ -94,7 +94,7 @@ func TestUpdateWithValues(t *testing.T) { lSQLBuilder := sqlite.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.UpdateWithValues("product", []string{"sku", "name"}, []string{"id"}, ormshift.ColumnsValues{"id": 1, "sku": "1.005.12.5", "name": "Trufa Sabor Amarula 18g Cacaushow"}) - lExpectedSQL := "update product set sku = @sku,name = @name where id = @id" + lExpectedSQL := "update \"product\" set \"sku\" = @sku,\"name\" = @name where \"id\" = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.UpdateWithValues.SQL") testutils.AssertEqualWithLabel(t, 3, len(lReturnedValues), "SQLBuilder.UpdateWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "id", Value: 1}, "SQLBuilder.UpdateWithValues.Values[0]") @@ -106,7 +106,7 @@ func TestDelete(t *testing.T) { lSQLBuilder := sqlite.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Delete("product", []string{"id"}) - lExpectedSQL := "delete from product where id = @id" + lExpectedSQL := "delete from \"product\" where \"id\" = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Delete") } @@ -114,7 +114,7 @@ func TestDeleteWithValues(t *testing.T) { lSQLBuilder := sqlite.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.DeleteWithValues("product", ormshift.ColumnsValues{"id": 1}) - lExpectedSQL := "delete from product where id = @id" + lExpectedSQL := "delete from \"product\" where \"id\" = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.DeleteWithValues.SQL") testutils.AssertEqualWithLabel(t, 1, len(lReturnedValues), "SQLBuilder.DeleteWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "id", Value: 1}, "SQLBuilder.DeleteWithValues.Values[0]") @@ -124,7 +124,7 @@ func TestSelect(t *testing.T) { lSQLBuilder := sqlite.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Select("product", []string{"id", "name", "description"}, []string{"sku", "active"}) - lExpectedSQL := "select id,name,description from product where sku = @sku and active = @active" + lExpectedSQL := "select \"id\",\"name\",\"description\" from \"product\" where \"sku\" = @sku and \"active\" = @active" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Select") } @@ -132,7 +132,7 @@ func TestSelectWithValues(t *testing.T) { lSQLBuilder := sqlite.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.SelectWithValues("product", []string{"id", "sku", "name", "description"}, ormshift.ColumnsValues{"category_id": 1, "active": true}) - lExpectedSQL := "select id,sku,name,description from product where active = @active and category_id = @category_id" + lExpectedSQL := "select \"id\",\"sku\",\"name\",\"description\" from \"product\" where \"active\" = @active and \"category_id\" = @category_id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.SelectWithValues.SQL") testutils.AssertEqualWithLabel(t, 2, len(lReturnedValues), "SQLBuilder.SelectWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "active", Value: true}, "SQLBuilder.SelectWithValues.Values[0]") diff --git a/dialects/sqlserver/builder_test.go b/dialects/sqlserver/builder_test.go index baadb47..37fe538 100644 --- a/dialects/sqlserver/builder_test.go +++ b/dialects/sqlserver/builder_test.go @@ -22,14 +22,14 @@ func TestCreateTable(t *testing.T) { lSQLBuilder := sqlserver.Driver().SQLBuilder() lUserTable := testutils.FakeUserTable(t) - lExpectedSQL := "CREATE TABLE user (id BIGINT NOT NULL IDENTITY (1, 1),email VARCHAR(80) NOT NULL,name VARCHAR(50) NOT NULL," + - "password_hash VARCHAR(256),active BIT,created_at DATETIME2(6),user_master BIGINT,master_user_id BIGINT," + - "licence_price MONEY,relevance FLOAT,photo VARBINARY(MAX),any VARCHAR,CONSTRAINT PK_user PRIMARY KEY (id,email));" + lExpectedSQL := "CREATE TABLE [user] ([id] BIGINT NOT NULL IDENTITY (1, 1),[email] VARCHAR(80) NOT NULL,[name] VARCHAR(50) NOT NULL," + + "[password_hash] VARCHAR(256),[active] BIT,[created_at] DATETIME2(6),[user_master] BIGINT,[master_user_id] BIGINT," + + "[licence_price] MONEY,[relevance] FLOAT,[photo] VARBINARY(MAX),[any] VARCHAR,CONSTRAINT [PK_user] PRIMARY KEY ([id],[email]));" lReturnedSQL := lSQLBuilder.CreateTable(lUserTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") lProductAttributeTable := testutils.FakeProductAttributeTable(t) - lExpectedSQL = "CREATE TABLE product_attribute (product_id BIGINT NOT NULL,attribute_id BIGINT NOT NULL,value VARCHAR(75),position BIGINT,CONSTRAINT PK_product_attribute PRIMARY KEY (product_id,attribute_id));" + lExpectedSQL = "CREATE TABLE [product_attribute] ([product_id] BIGINT NOT NULL,[attribute_id] BIGINT NOT NULL,[value] VARCHAR(75),[position] BIGINT,CONSTRAINT [PK_product_attribute] PRIMARY KEY ([product_id],[attribute_id]));" lReturnedSQL = lSQLBuilder.CreateTable(lProductAttributeTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") } @@ -38,7 +38,7 @@ func TestDropTable(t *testing.T) { lSQLBuilder := sqlserver.Driver().SQLBuilder() lUserTableName := testutils.FakeUserTableName(t) - lExpectedSQL := "DROP TABLE user;" + lExpectedSQL := "DROP TABLE [user];" lReturnedSQL := lSQLBuilder.DropTable(lUserTableName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.DropTable") } @@ -48,7 +48,7 @@ func TestAlterTableAddColumn(t *testing.T) { lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumn := testutils.FakeUpdatedAtColumn(t) - lExpectedSQL := "ALTER TABLE user ADD COLUMN updated_at DATETIME2(6);" + lExpectedSQL := "ALTER TABLE [user] ADD COLUMN [updated_at] DATETIME2(6);" lReturnedSQL := lSQLBuilder.AlterTableAddColumn(lUserTableName, lUpdatedAtColumn) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableAddColumn") } @@ -58,7 +58,7 @@ func TestAlterTableDropColumn(t *testing.T) { lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumnName := testutils.FakeUpdatedAtColumnName(t) - lExpectedSQL := "ALTER TABLE user DROP COLUMN updated_at;" + lExpectedSQL := "ALTER TABLE [user] DROP COLUMN [updated_at];" lReturnedSQL := lSQLBuilder.AlterTableDropColumn(lUserTableName, lUpdatedAtColumnName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableDropColumn") } @@ -67,7 +67,7 @@ func TestInsert(t *testing.T) { lSQLBuilder := sqlserver.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Insert("product", []string{"id", "sku", "name", "description"}) - lExpectedSQL := "insert into product (id,sku,name,description) values (@id,@sku,@name,@description)" + lExpectedSQL := "insert into [product] ([id],[sku],[name],[description]) values (@id,@sku,@name,@description)" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Insert") } @@ -75,7 +75,7 @@ func TestInsertWithValues(t *testing.T) { lSQLBuilder := sqlserver.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.InsertWithValues("product", ormshift.ColumnsValues{"id": 1, "sku": "1.005.12.9", "name": "Trufa Sabor Amarula 30g Cacaushow"}) - lExpectedSQL := "insert into product (id,name,sku) values (@id,@name,@sku)" + lExpectedSQL := "insert into [product] ([id],[name],[sku]) values (@id,@name,@sku)" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.InsertWithValues.SQL") testutils.AssertEqualWithLabel(t, 3, len(lReturnedValues), "SQLBuilder.InsertWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "id", Value: 1}, "SQLBuilder.InsertWithValues.Values[0]") @@ -87,7 +87,7 @@ func TestUpdate(t *testing.T) { lSQLBuilder := sqlserver.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Update("product", []string{"sku", "name", "description"}, []string{"id"}) - lExpectedSQL := "update product set sku = @sku,name = @name,description = @description where id = @id" + lExpectedSQL := "update [product] set [sku] = @sku,[name] = @name,[description] = @description where [id] = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Update") } @@ -95,7 +95,7 @@ func TestUpdateWithValues(t *testing.T) { lSQLBuilder := sqlserver.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.UpdateWithValues("product", []string{"sku", "name"}, []string{"id"}, ormshift.ColumnsValues{"id": 1, "sku": "1.005.12.5", "name": "Trufa Sabor Amarula 18g Cacaushow"}) - lExpectedSQL := "update product set sku = @sku,name = @name where id = @id" + lExpectedSQL := "update [product] set [sku] = @sku,[name] = @name where [id] = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.UpdateWithValues.SQL") testutils.AssertEqualWithLabel(t, 3, len(lReturnedValues), "SQLBuilder.UpdateWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "id", Value: 1}, "SQLBuilder.UpdateWithValues.Values[0]") @@ -107,7 +107,7 @@ func TestDelete(t *testing.T) { lSQLBuilder := sqlserver.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Delete("product", []string{"id"}) - lExpectedSQL := "delete from product where id = @id" + lExpectedSQL := "delete from [product] where [id] = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Delete") } @@ -115,7 +115,7 @@ func TestDeleteWithValues(t *testing.T) { lSQLBuilder := sqlserver.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.DeleteWithValues("product", ormshift.ColumnsValues{"id": 1}) - lExpectedSQL := "delete from product where id = @id" + lExpectedSQL := "delete from [product] where [id] = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.DeleteWithValues.SQL") testutils.AssertEqualWithLabel(t, 1, len(lReturnedValues), "SQLBuilder.DeleteWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "id", Value: 1}, "SQLBuilder.DeleteWithValues.Values[0]") @@ -125,7 +125,7 @@ func TestSelect(t *testing.T) { lSQLBuilder := sqlserver.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Select("product", []string{"id", "name", "description"}, []string{"sku", "active"}) - lExpectedSQL := "select id,name,description from product where sku = @sku and active = @active" + lExpectedSQL := "select [id],[name],[description] from [product] where [sku] = @sku and [active] = @active" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Select") } @@ -133,7 +133,7 @@ func TestSelectWithValues(t *testing.T) { lSQLBuilder := sqlserver.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.SelectWithValues("product", []string{"id", "sku", "name", "description"}, ormshift.ColumnsValues{"category_id": 1, "active": true}) - lExpectedSQL := "select id,sku,name,description from product where active = @active and category_id = @category_id" + lExpectedSQL := "select [id],[sku],[name],[description] from [product] where [active] = @active and [category_id] = @category_id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.SelectWithValues.SQL") testutils.AssertEqualWithLabel(t, 2, len(lReturnedValues), "SQLBuilder.SelectWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "active", Value: true}, "SQLBuilder.SelectWithValues.Values[0]") diff --git a/internal/builder_generic_test.go b/internal/builder_generic_test.go index 50589e2..1ead992 100644 --- a/internal/builder_generic_test.go +++ b/internal/builder_generic_test.go @@ -14,14 +14,14 @@ func TestCreateTable(t *testing.T) { lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lUserTable := testutils.FakeUserTable(t) - lExpectedSQL := "CREATE TABLE user (id <>,email <>,name <>,password_hash <>," + - "active <>,created_at <>,user_master <>,master_user_id <>," + - "licence_price <>,relevance <>,photo <>,any <>,PRIMARY KEY (id,email));" + lExpectedSQL := "CREATE TABLE \"user\" (\"id\" <>,\"email\" <>,\"name\" <>,\"password_hash\" <>," + + "\"active\" <>,\"created_at\" <>,\"user_master\" <>,\"master_user_id\" <>," + + "\"licence_price\" <>,\"relevance\" <>,\"photo\" <>,\"any\" <>,PRIMARY KEY (\"id\",\"email\"));" lReturnedSQL := lSQLBuilder.CreateTable(lUserTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") lProductAttributeTable := testutils.FakeProductAttributeTable(t) - lExpectedSQL = "CREATE TABLE product_attribute (product_id <>,attribute_id <>,value <>,position <>,PRIMARY KEY (product_id,attribute_id));" + lExpectedSQL = "CREATE TABLE \"product_attribute\" (\"product_id\" <>,\"attribute_id\" <>,\"value\" <>,\"position\" <>,PRIMARY KEY (\"product_id\",\"attribute_id\"));" lReturnedSQL = lSQLBuilder.CreateTable(lProductAttributeTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") } @@ -30,7 +30,7 @@ func TestDropTable(t *testing.T) { lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lUserTableName := testutils.FakeUserTableName(t) - lExpectedSQL := "DROP TABLE user;" + lExpectedSQL := "DROP TABLE \"user\";" lReturnedSQL := lSQLBuilder.DropTable(lUserTableName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.DropTable") } @@ -40,7 +40,7 @@ func TestAlterTableAddColumn(t *testing.T) { lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumn := testutils.FakeUpdatedAtColumn(t) - lExpectedSQL := "ALTER TABLE user ADD COLUMN updated_at <>;" + lExpectedSQL := "ALTER TABLE \"user\" ADD COLUMN \"updated_at\" <>;" lReturnedSQL := lSQLBuilder.AlterTableAddColumn(lUserTableName, lUpdatedAtColumn) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableAddColumn") } @@ -50,7 +50,7 @@ func TestAlterTableDropColumn(t *testing.T) { lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumnName := testutils.FakeUpdatedAtColumnName(t) - lExpectedSQL := "ALTER TABLE user DROP COLUMN updated_at;" + lExpectedSQL := "ALTER TABLE \"user\" DROP COLUMN \"updated_at\";" lReturnedSQL := lSQLBuilder.AlterTableDropColumn(lUserTableName, lUpdatedAtColumnName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableDropColumn") } @@ -59,7 +59,7 @@ func TestInsert(t *testing.T) { lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL := lSQLBuilder.Insert("product", []string{"id", "sku", "name", "description"}) - lExpectedSQL := "insert into product (id,sku,name,description) values (@id,@sku,@name,@description)" + lExpectedSQL := "insert into \"product\" (\"id\",\"sku\",\"name\",\"description\") values (@id,@sku,@name,@description)" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Insert") } @@ -67,7 +67,7 @@ func TestInsertWithValues(t *testing.T) { lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL, lReturnedValues := lSQLBuilder.InsertWithValues("product", ormshift.ColumnsValues{"id": 1, "sku": "1.005.12.9", "name": "Trufa Sabor Amarula 30g Cacaushow"}) - lExpectedSQL := "insert into product (id,name,sku) values (@id,@name,@sku)" + lExpectedSQL := "insert into \"product\" (\"id\",\"name\",\"sku\") values (@id,@name,@sku)" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.InsertWithValues.SQL") testutils.AssertEqualWithLabel(t, 3, len(lReturnedValues), "SQLBuilder.InsertWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "id", Value: 1}, "SQLBuilder.InsertWithValues.Values[0]") @@ -79,7 +79,7 @@ func TestUpdate(t *testing.T) { lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL := lSQLBuilder.Update("product", []string{"sku", "name", "description"}, []string{"id"}) - lExpectedSQL := "update product set sku = @sku,name = @name,description = @description where id = @id" + lExpectedSQL := "update \"product\" set \"sku\" = @sku,\"name\" = @name,\"description\" = @description where \"id\" = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Update") } @@ -87,7 +87,7 @@ func TestUpdateWithValues(t *testing.T) { lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL, lReturnedValues := lSQLBuilder.UpdateWithValues("product", []string{"sku", "name"}, []string{"id"}, ormshift.ColumnsValues{"id": 1, "sku": "1.005.12.5", "name": "Trufa Sabor Amarula 18g Cacaushow"}) - lExpectedSQL := "update product set sku = @sku,name = @name where id = @id" + lExpectedSQL := "update \"product\" set \"sku\" = @sku,\"name\" = @name where \"id\" = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.UpdateWithValues.SQL") testutils.AssertEqualWithLabel(t, 3, len(lReturnedValues), "SQLBuilder.UpdateWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "id", Value: 1}, "SQLBuilder.UpdateWithValues.Values[0]") @@ -99,7 +99,7 @@ func TestDelete(t *testing.T) { lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL := lSQLBuilder.Delete("product", []string{"id"}) - lExpectedSQL := "delete from product where id = @id" + lExpectedSQL := "delete from \"product\" where \"id\" = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Delete") } @@ -107,7 +107,7 @@ func TestDeleteWithValues(t *testing.T) { lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL, lReturnedValues := lSQLBuilder.DeleteWithValues("product", ormshift.ColumnsValues{"id": 1}) - lExpectedSQL := "delete from product where id = @id" + lExpectedSQL := "delete from \"product\" where \"id\" = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.DeleteWithValues.SQL") testutils.AssertEqualWithLabel(t, 1, len(lReturnedValues), "SQLBuilder.DeleteWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "id", Value: 1}, "SQLBuilder.DeleteWithValues.Values[0]") @@ -117,7 +117,7 @@ func TestSelect(t *testing.T) { lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL := lSQLBuilder.Select("product", []string{"id", "name", "description"}, []string{"sku", "active"}) - lExpectedSQL := "select id,name,description from product where sku = @sku and active = @active" + lExpectedSQL := "select \"id\",\"name\",\"description\" from \"product\" where \"sku\" = @sku and \"active\" = @active" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Select") } @@ -125,7 +125,7 @@ func TestSelectWithValues(t *testing.T) { lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL, lReturnedValues := lSQLBuilder.SelectWithValues("product", []string{"id", "sku", "name", "description"}, ormshift.ColumnsValues{"category_id": 1, "active": true}) - lExpectedSQL := "select id,sku,name,description from product where active = @active and category_id = @category_id" + lExpectedSQL := "select \"id\",\"sku\",\"name\",\"description\" from \"product\" where \"active\" = @active and \"category_id\" = @category_id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.SelectWithValues.SQL") testutils.AssertEqualWithLabel(t, 2, len(lReturnedValues), "SQLBuilder.SelectWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "active", Value: true}, "SQLBuilder.SelectWithValues.Values[0]") @@ -154,5 +154,5 @@ func TestColumnDefinition(t *testing.T) { lColumn := schema.NewColumn(schema.NewColumnParams{Name: "column_name", Type: schema.Integer, Size: 0}) lTableName := "test_table" lReturnedSQL := lSQLBuilder.AlterTableAddColumn(lTableName, lColumn) - testutils.AssertEqualWithLabel(t, "ALTER TABLE test_table ADD COLUMN fake;", lReturnedSQL, "SQLBuilder.ColumnDefinition") + testutils.AssertEqualWithLabel(t, "ALTER TABLE \"test_table\" ADD COLUMN fake;", lReturnedSQL, "SQLBuilder.ColumnDefinition") } From d320e1a993f535548125dfad1963751675447a10 Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 15:53:59 -0300 Subject: [PATCH 17/26] quoted missing identifiers --- dialects/sqlite/builder.go | 2 +- internal/builder_generic.go | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/dialects/sqlite/builder.go b/dialects/sqlite/builder.go index 6dff7d4..f26fdea 100644 --- a/dialects/sqlite/builder.go +++ b/dialects/sqlite/builder.go @@ -33,7 +33,7 @@ func (sb *sqliteBuilder) CreateTable(pTable schema.Table) string { if lPKColumns != "" { lPKColumns += "," } - lPKColumns += lColumn.Name() + lPKColumns += sb.QuoteIdentifier(lColumn.Name()) } if !lHasAutoIncrementColumn { diff --git a/internal/builder_generic.go b/internal/builder_generic.go index c49d939..5fb44b3 100644 --- a/internal/builder_generic.go +++ b/internal/builder_generic.go @@ -147,7 +147,11 @@ func (sb *genericSQLBuilder) SelectWithPagination(pSQLSelectCommand string, pRow } func (sb *genericSQLBuilder) columnsList(pColumns []string) string { - return strings.Join(pColumns, ",") + lQuotedColumns := []string{} + for _, col := range pColumns { + lQuotedColumns = append(lQuotedColumns, sb.QuoteIdentifier(col)) + } + return strings.Join(lQuotedColumns, ",") } func (sb *genericSQLBuilder) namesList(pColumns []string) string { @@ -164,7 +168,7 @@ func (sb *genericSQLBuilder) columnEqualNameList(pColumns []string, pSeparator s if lColumnEqualNameList != "" { lColumnEqualNameList += pSeparator } - lColumnEqualNameList += fmt.Sprintf("%s = @%s", lColumn, lColumn) + lColumnEqualNameList += fmt.Sprintf("%s = @%s", sb.QuoteIdentifier(lColumn), lColumn) } return lColumnEqualNameList } From c49294c464ad53ceffb28542f0bef363ce182564 Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 15:54:15 -0300 Subject: [PATCH 18/26] fixed broken test after changing structs to use pointer receivers --- database_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/database_test.go b/database_test.go index 8c98070..852b163 100644 --- a/database_test.go +++ b/database_test.go @@ -124,7 +124,7 @@ func TestDriverConnectionString(t *testing.T) { func TestDriverSQLBuilder(t *testing.T) { lDriver := testutils.NewFakeDriver(sqlite.Driver()) lSQLBuilder := lDriver.SQLBuilder() - testutils.AssertEqualWithLabel(t, "sqliteBuilder", reflect.TypeOf(lSQLBuilder).Name(), "FakeDriver.SQLBuilder") + testutils.AssertEqualWithLabel(t, "sqliteBuilder", reflect.TypeOf(lSQLBuilder).Elem().Name(), "FakeDriver.SQLBuilder") } func TestDriverDBSchema(t *testing.T) { From 5a4a3085c4ba8e2f86c8041dae2303c5ff2ffe38 Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 16:07:25 -0300 Subject: [PATCH 19/26] removed redundant executor from database --- database.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/database.go b/database.go index c641af1..7edc394 100644 --- a/database.go +++ b/database.go @@ -28,7 +28,6 @@ type DatabaseDriver interface { type Database struct { driver DatabaseDriver db *sql.DB - executor SQLExecutor connectionString string sqlBuilder SQLBuilder schema *schema.DBSchema @@ -48,12 +47,9 @@ func OpenDatabase(pDriver DatabaseDriver, pParams ConnectionParams) (*Database, return nil, fmt.Errorf("failed to get DB schema: %w", lError) } - // TODO: Unify SQLExecutor interface usage - var lExecutor SQLExecutor = lDB return &Database{ driver: pDriver, db: lDB, - executor: lExecutor, connectionString: lConnectionString, sqlBuilder: pDriver.SQLBuilder(), schema: lSchema, @@ -69,7 +65,7 @@ func (d *Database) DB() *sql.DB { } func (d *Database) SQLExecutor() SQLExecutor { - return d.executor + return d.db } func (d *Database) DriverName() string { From 9e5c65e7490f126fc6acc6d4859e952af6da85c7 Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 16:08:30 -0300 Subject: [PATCH 20/26] grouped parameters of same type together --- builder.go | 2 +- dialects/postgresql/builder.go | 2 +- dialects/sqlite/builder.go | 2 +- dialects/sqlserver/builder.go | 2 +- internal/builder_generic.go | 2 +- schema/schema.go | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/builder.go b/builder.go index 736383a..95ff114 100644 --- a/builder.go +++ b/builder.go @@ -12,7 +12,7 @@ type DDLSQLBuilder interface { CreateTable(pTable schema.Table) string DropTable(pTableName string) string AlterTableAddColumn(pTableName string, pColumn schema.Column) string - AlterTableDropColumn(pTableName string, pColumnName string) string + AlterTableDropColumn(pTableName, pColumnName string) string ColumnTypeAsString(pColumnType schema.ColumnType) string } diff --git a/dialects/postgresql/builder.go b/dialects/postgresql/builder.go index 9607eac..ba0fd99 100644 --- a/dialects/postgresql/builder.go +++ b/dialects/postgresql/builder.go @@ -33,7 +33,7 @@ func (sb *postgresqlBuilder) AlterTableAddColumn(pTableName string, pColumn sche return sb.generic.AlterTableAddColumn(pTableName, pColumn) } -func (sb *postgresqlBuilder) AlterTableDropColumn(pTableName string, pColumnName string) string { +func (sb *postgresqlBuilder) AlterTableDropColumn(pTableName, pColumnName string) string { return sb.generic.AlterTableDropColumn(pTableName, pColumnName) } diff --git a/dialects/sqlite/builder.go b/dialects/sqlite/builder.go index f26fdea..d158670 100644 --- a/dialects/sqlite/builder.go +++ b/dialects/sqlite/builder.go @@ -59,7 +59,7 @@ func (sb *sqliteBuilder) AlterTableAddColumn(pTableName string, pColumn schema.C return sb.generic.AlterTableAddColumn(pTableName, pColumn) } -func (sb *sqliteBuilder) AlterTableDropColumn(pTableName string, pColumnName string) string { +func (sb *sqliteBuilder) AlterTableDropColumn(pTableName, pColumnName string) string { return sb.generic.AlterTableDropColumn(pTableName, pColumnName) } diff --git a/dialects/sqlserver/builder.go b/dialects/sqlserver/builder.go index 6516aa5..12a9871 100644 --- a/dialects/sqlserver/builder.go +++ b/dialects/sqlserver/builder.go @@ -55,7 +55,7 @@ func (sb *sqlserverBuilder) AlterTableAddColumn(pTableName string, pColumn schem return sb.generic.AlterTableAddColumn(pTableName, pColumn) } -func (sb *sqlserverBuilder) AlterTableDropColumn(pTableName string, pColumnName string) string { +func (sb *sqlserverBuilder) AlterTableDropColumn(pTableName, pColumnName string) string { return sb.generic.AlterTableDropColumn(pTableName, pColumnName) } diff --git a/internal/builder_generic.go b/internal/builder_generic.go index 5fb44b3..1f061bd 100644 --- a/internal/builder_generic.go +++ b/internal/builder_generic.go @@ -67,7 +67,7 @@ func (sb *genericSQLBuilder) AlterTableAddColumn(pTableName string, pColumn sche return fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s;", sb.QuoteIdentifier(pTableName), sb.columnDefinition(pColumn)) } -func (sb *genericSQLBuilder) AlterTableDropColumn(pTableName string, pColumnName string) string { +func (sb *genericSQLBuilder) AlterTableDropColumn(pTableName, pColumnName string) string { return fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s;", sb.QuoteIdentifier(pTableName), sb.QuoteIdentifier(pColumnName)) } diff --git a/schema/schema.go b/schema/schema.go index 64922ce..4c48360 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -61,7 +61,7 @@ func (s *DBSchema) fetchTableNames() (rTableNames []string, rError error) { return } -func (s *DBSchema) HasColumn(pTableName string, pColumnName string) bool { +func (s *DBSchema) HasColumn(pTableName, pColumnName string) bool { lColumnTypes, lError := s.fetchColumnTypes(pTableName) if lError != nil { return false From efba7c613f2b915a4de67cbe6c8f5b8e37366249 Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 16:21:42 -0300 Subject: [PATCH 21/26] db schema name consistence --- database.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/database.go b/database.go index 7edc394..29b3bff 100644 --- a/database.go +++ b/database.go @@ -30,7 +30,7 @@ type Database struct { db *sql.DB connectionString string sqlBuilder SQLBuilder - schema *schema.DBSchema + dbSchema *schema.DBSchema } func OpenDatabase(pDriver DatabaseDriver, pParams ConnectionParams) (*Database, error) { @@ -42,7 +42,7 @@ func OpenDatabase(pDriver DatabaseDriver, pParams ConnectionParams) (*Database, if lError != nil { return nil, fmt.Errorf("sql.Open failed: %w", lError) } - lSchema, lError := pDriver.DBSchema(lDB) + lDBSchema, lError := pDriver.DBSchema(lDB) if lError != nil { return nil, fmt.Errorf("failed to get DB schema: %w", lError) } @@ -52,7 +52,7 @@ func OpenDatabase(pDriver DatabaseDriver, pParams ConnectionParams) (*Database, db: lDB, connectionString: lConnectionString, sqlBuilder: pDriver.SQLBuilder(), - schema: lSchema, + dbSchema: lDBSchema, }, nil } @@ -82,5 +82,5 @@ func (d *Database) SQLBuilder() SQLBuilder { } func (d *Database) DBSchema() *schema.DBSchema { - return d.schema + return d.dbSchema } From ceadebf770196ba2c197e0630402b99272babbd7 Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 16:22:01 -0300 Subject: [PATCH 22/26] added simple tests for dialect-specific schema --- dialects/postgresql/schema_test.go | 33 ++++++++++++++++++++++++++++++ dialects/sqlite/schema_test.go | 33 ++++++++++++++++++++++++++++++ dialects/sqlserver/schema_test.go | 33 ++++++++++++++++++++++++++++++ 3 files changed, 99 insertions(+) create mode 100644 dialects/postgresql/schema_test.go create mode 100644 dialects/sqlite/schema_test.go create mode 100644 dialects/sqlserver/schema_test.go diff --git a/dialects/postgresql/schema_test.go b/dialects/postgresql/schema_test.go new file mode 100644 index 0000000..4c44cd1 --- /dev/null +++ b/dialects/postgresql/schema_test.go @@ -0,0 +1,33 @@ +package postgresql_test + +import ( + "testing" + + "github.com/ordershift/ormshift" + "github.com/ordershift/ormshift/dialects/postgresql" + "github.com/ordershift/ormshift/internal/testutils" +) + +func TestHasTable(t *testing.T) { + lDB, lError := ormshift.OpenDatabase(postgresql.Driver(), ormshift.ConnectionParams{InMemory: true}) + if !testutils.AssertNotNilResultAndNilError(t, lDB, lError, "ormshift.OpenDatabase") { + return + } + defer func() { _ = lDB.Close() }() + + if lDB.DBSchema().HasTable("user") { + t.Errorf("Expected HasTable('user') to be false") + } +} + +func TestHasColumn(t *testing.T) { + lDB, lError := ormshift.OpenDatabase(postgresql.Driver(), ormshift.ConnectionParams{InMemory: true}) + if !testutils.AssertNotNilResultAndNilError(t, lDB, lError, "ormshift.OpenDatabase") { + return + } + defer func() { _ = lDB.Close() }() + + if lDB.DBSchema().HasColumn("user", "id") { + t.Errorf("Expected HasColumn('user', 'id') to be false") + } +} diff --git a/dialects/sqlite/schema_test.go b/dialects/sqlite/schema_test.go new file mode 100644 index 0000000..2718ff7 --- /dev/null +++ b/dialects/sqlite/schema_test.go @@ -0,0 +1,33 @@ +package sqlite_test + +import ( + "testing" + + "github.com/ordershift/ormshift" + "github.com/ordershift/ormshift/dialects/sqlite" + "github.com/ordershift/ormshift/internal/testutils" +) + +func TestHasTable(t *testing.T) { + lDB, lError := ormshift.OpenDatabase(sqlite.Driver(), ormshift.ConnectionParams{InMemory: true}) + if !testutils.AssertNotNilResultAndNilError(t, lDB, lError, "ormshift.OpenDatabase") { + return + } + defer func() { _ = lDB.Close() }() + + if lDB.DBSchema().HasTable("user") { + t.Errorf("Expected HasTable('user') to be false") + } +} + +func TestHasColumn(t *testing.T) { + lDB, lError := ormshift.OpenDatabase(sqlite.Driver(), ormshift.ConnectionParams{InMemory: true}) + if !testutils.AssertNotNilResultAndNilError(t, lDB, lError, "ormshift.OpenDatabase") { + return + } + defer func() { _ = lDB.Close() }() + + if lDB.DBSchema().HasColumn("user", "id") { + t.Errorf("Expected HasColumn('user', 'id') to be false") + } +} diff --git a/dialects/sqlserver/schema_test.go b/dialects/sqlserver/schema_test.go new file mode 100644 index 0000000..db45dcf --- /dev/null +++ b/dialects/sqlserver/schema_test.go @@ -0,0 +1,33 @@ +package sqlserver_test + +import ( + "testing" + + "github.com/ordershift/ormshift" + "github.com/ordershift/ormshift/dialects/sqlserver" + "github.com/ordershift/ormshift/internal/testutils" +) + +func TestHasTable(t *testing.T) { + lDB, lError := ormshift.OpenDatabase(sqlserver.Driver(), ormshift.ConnectionParams{InMemory: true}) + if !testutils.AssertNotNilResultAndNilError(t, lDB, lError, "ormshift.OpenDatabase") { + return + } + defer func() { _ = lDB.Close() }() + + if lDB.DBSchema().HasTable("user") { + t.Errorf("Expected HasTable('user') to be false") + } +} + +func TestHasColumn(t *testing.T) { + lDB, lError := ormshift.OpenDatabase(sqlserver.Driver(), ormshift.ConnectionParams{InMemory: true}) + if !testutils.AssertNotNilResultAndNilError(t, lDB, lError, "ormshift.OpenDatabase") { + return + } + defer func() { _ = lDB.Close() }() + + if lDB.DBSchema().HasColumn("user", "id") { + t.Errorf("Expected HasColumn('user', 'id') to be false") + } +} From 776046c7d7168d15e1124d2fee64736816349e1d Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 16:24:52 -0300 Subject: [PATCH 23/26] removed potentially unnecessary NOSONAR comment --- migrations/migrator.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/migrations/migrator.go b/migrations/migrator.go index 110750c..7a83934 100644 --- a/migrations/migrator.go +++ b/migrations/migrator.go @@ -187,6 +187,6 @@ func ensureMigrationsTableExists(pDatabase *ormshift.Database, pConfig *Migrator return lError } - _, lError = pDatabase.SQLExecutor().Exec(pDatabase.SQLBuilder().CreateTable(lMigrationsTable)) // NOSONAR go:S2077 - Dynamic SQL is controlled and sanitized internally + _, lError = pDatabase.SQLExecutor().Exec(pDatabase.SQLBuilder().CreateTable(lMigrationsTable)) return lError } From 6c2bb8a05f3a63185cd8a2a1d0e315ce5c924e8a Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 16:25:06 -0300 Subject: [PATCH 24/26] added new tests for migrator --- migrations/migrator_test.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/migrations/migrator_test.go b/migrations/migrator_test.go index 519c18f..598b033 100644 --- a/migrations/migrator_test.go +++ b/migrations/migrator_test.go @@ -16,6 +16,18 @@ func TestNewMigratorWhenDatabaseIsNil(t *testing.T) { testutils.AssertErrorMessage(t, "database cannot be nil", lError, "migrations.NewMigrator[database=nil]") } +func TestNewMigratorWhenConfigIsNil(t *testing.T) { + lDB, lError := ormshift.OpenDatabase(sqlite.Driver(), ormshift.ConnectionParams{InMemory: true}) + if !testutils.AssertNotNilResultAndNilError(t, lDB, lError, "ormshift.OpenDatabase") { + return + } + defer func() { _ = lDB.Close() }() + + lMigrator, lError := migrations.NewMigrator(lDB, nil) + testutils.AssertNilResultAndNotNilError(t, lMigrator, lError, "migrations.NewMigrator[config=nil]") + testutils.AssertErrorMessage(t, "migrator config cannot be nil", lError, "migrations.NewMigrator[config=nil]") +} + func TestNewMigratorWhenDatabaseIsInvalid(t *testing.T) { lDriver := testutils.NewFakeDriverInvalidConnectionString(postgresql.Driver()) lDB, lError := ormshift.OpenDatabase(lDriver, ormshift.ConnectionParams{}) From 899b96242bebb77c047964151fe28e95feabd0fb Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 16:29:03 -0300 Subject: [PATCH 25/26] added tests for generic sql builder's QuoteIdentifier --- internal/builder_generic_test.go | 8 ++++++++ internal/testutils/fake.go | 4 ++++ 2 files changed, 12 insertions(+) diff --git a/internal/builder_generic_test.go b/internal/builder_generic_test.go index 1ead992..03d1a1c 100644 --- a/internal/builder_generic_test.go +++ b/internal/builder_generic_test.go @@ -156,3 +156,11 @@ func TestColumnDefinition(t *testing.T) { lReturnedSQL := lSQLBuilder.AlterTableAddColumn(lTableName, lColumn) testutils.AssertEqualWithLabel(t, "ALTER TABLE \"test_table\" ADD COLUMN fake;", lReturnedSQL, "SQLBuilder.ColumnDefinition") } + +func TestQuoteIdentifier(t *testing.T) { + lSQLBuilder := internal.NewGenericSQLBuilder(nil, testutils.FakeQuoteIdentifierFunc, nil) + lColumn := schema.NewColumn(schema.NewColumnParams{Name: "column_name", Type: schema.Integer, Size: 0}) + lTableName := "test_table" + lReturnedSQL := lSQLBuilder.AlterTableAddColumn(lTableName, lColumn) + testutils.AssertEqualWithLabel(t, "ALTER TABLE quoted_test_table ADD COLUMN quoted_column_name <>;", lReturnedSQL, "SQLBuilder.QuoteIdentifier") +} diff --git a/internal/testutils/fake.go b/internal/testutils/fake.go index 48207c0..245bfc4 100644 --- a/internal/testutils/fake.go +++ b/internal/testutils/fake.go @@ -165,3 +165,7 @@ func FakeInteroperateSQLCommandWithNamedArgsFunc(command string, namedArgs ...sq func FakeColumnDefinitionFunc(column schema.Column) string { return "fake" } + +func FakeQuoteIdentifierFunc(identifier string) string { + return "quoted_" + identifier +} From 102693a14a98c66a01e69aaee334c299de9b8815 Mon Sep 17 00:00:00 2001 From: Guilherme Raduenz Date: Tue, 27 Jan 2026 16:49:51 -0300 Subject: [PATCH 26/26] fixed and updated documentation --- README.md | 83 +++++++++++++++++++------------------------------------ 1 file changed, 29 insertions(+), 54 deletions(-) diff --git a/README.md b/README.md index 8f39d88..fb46ca2 100644 --- a/README.md +++ b/README.md @@ -120,23 +120,18 @@ defer db.Close() ### Create a table with the SQLBuilder ```go -table, err := ormshift.NewTable("users") -if err != nil { - // handle error -} +table := schema.NewTable("users") -columns := []schema.NewColumnParams{ - {Name: "id", Type: ormshift.Integer, PrimaryKey: true, AutoIncrement: true}, - {Name: "name", Type: ormshift.Varchar, Size: 50, NotNull: false}, -} +err := table.AddColumns( + schema.NewColumnParams{Name: "id", Type: schema.Integer, PrimaryKey: true, AutoIncrement: true}, + schema.NewColumnParams{Name: "name", Type: schema.Varchar, Size: 50, NotNull: false}, +) -for _, col := range columns { - if err := table.AddColumn(col); err != nil { - // handle error - } +if err != nil { + // handle error } -db.SQLExecutor().Exec(db.SQLBuilder().CreateTable(*table)) +db.SQLExecutor().Exec(db.SQLBuilder().CreateTable(table)) ``` ### CRUD with the SQLBuilder and SQLExecutor @@ -209,41 +204,35 @@ type M0001CreateUserTable struct{} func (m M0001CreateUserTable) Up(migrator *migrations.Migrator) error { db := migrator.Database() - table, err := schema.NewTable("user") - if err != nil { - return err - } + table := schema.NewTable("user") if db.DBSchema().HasTable(table.Name()) { // if the table already exists, nothing to do here return nil } - columns := []schema.NewColumnParams{ - {Name: "id", Type: schema.Integer, PrimaryKey: true, AutoIncrement: true}, - {Name: "name", Type: schema.Varchar, Size: 50, NotNull: false}, - {Name: "email", Type: schema.Varchar, Size: 120, NotNull: false}, - {Name: "is_active", Type: schema.Boolean, NotNull: false}, - } - - for _, col := range columns { - if err := table.AddColumn(col); err != nil { - return err - } + err := table.AddColumns( + schema.NewColumnParams{Name: "id", Type: schema.Integer, PrimaryKey: true, AutoIncrement: true}, + schema.NewColumnParams{Name: "name", Type: schema.Varchar, Size: 50, NotNull: false}, + schema.NewColumnParams{Name: "email", Type: schema.Varchar, Size: 120, NotNull: false}, + schema.NewColumnParams{Name: "is_active", Type: schema.Boolean, NotNull: false}, + ) + if err != nil { + return err } - _, err := db.SQLExecutor().Exec(db.SQLBuilder().CreateTable(*table)) + _, err = db.SQLExecutor().Exec(db.SQLBuilder().CreateTable(table)) return err } func (m M0001CreateUserTable) Down(migrator *migrations.Migrator) error { db := migrator.Database() - tableName, _ := schema.NewTableName("user") - if !db.DBSchema().HasTable(*tableName) { + tableName := "user" + if !db.DBSchema().HasTable(tableName) { // if the table already doesn't exist, nothing to do here return nil } - _, err := db.SQLExecutor().Exec(db.SQLBuilder().DropTable(*tableName)) + _, err := db.SQLExecutor().Exec(db.SQLBuilder().DropTable(tableName)) return err } ``` @@ -256,42 +245,28 @@ type M0002AddUpdatedAtColumn struct{} func (m M0002AddUpdatedAtColumn) Up(migrator *migrations.Migrator) error { db := migrator.Database() - tableName, err := schema.NewTableName("user") - if err != nil { - return err - } - - col, err := schema.NewColumn(schema.NewColumnParams{Name: "updated_at", Type: schema.DateTime}) - if err != nil { - return err - } + tableName := "user" + col := schema.NewColumn(schema.NewColumnParams{Name: "updated_at", Type: schema.DateTime}) - if db.DBSchema().HasColumn(*tableName, col.Name()) { + if db.DBSchema().HasColumn(tableName, col.Name()) { // if the column already exists, nothing to do here return nil } - _, err := db.SQLExecutor().Exec(db.SQLBuilder().AlterTableAddColumn(*tableName, *col)) + _, err := db.SQLExecutor().Exec(db.SQLBuilder().AlterTableAddColumn(tableName, col)) return err } func (m M0002AddUpdatedAtColumn) Down(migrator *migrations.Migrator) error { db := migrator.Database() - tableName, err := schema.NewTableName("user") - if err != nil { - return err - } - - colName, err := schema.NewColumnName("updated_at") - if err != nil { - return err - } + tableName := "user" + colName := "updated_at" - if !db.DBSchema().HasColumn(*tableName, *colName) { + if !db.DBSchema().HasColumn(tableName, colName) { // if the column already doesn't exist, nothing to do here return nil } - _, err := db.SQLExecutor().Exec(db.SQLBuilder().AlterTableDropColumn(*tableName, *colName)) + _, err := db.SQLExecutor().Exec(db.SQLBuilder().AlterTableDropColumn(tableName, colName)) return err } ```