diff --git a/README.md b/README.md index 8f39d88..fb46ca2 100644 --- a/README.md +++ b/README.md @@ -120,23 +120,18 @@ defer db.Close() ### Create a table with the SQLBuilder ```go -table, err := ormshift.NewTable("users") -if err != nil { - // handle error -} +table := schema.NewTable("users") -columns := []schema.NewColumnParams{ - {Name: "id", Type: ormshift.Integer, PrimaryKey: true, AutoIncrement: true}, - {Name: "name", Type: ormshift.Varchar, Size: 50, NotNull: false}, -} +err := table.AddColumns( + schema.NewColumnParams{Name: "id", Type: schema.Integer, PrimaryKey: true, AutoIncrement: true}, + schema.NewColumnParams{Name: "name", Type: schema.Varchar, Size: 50, NotNull: false}, +) -for _, col := range columns { - if err := table.AddColumn(col); err != nil { - // handle error - } +if err != nil { + // handle error } -db.SQLExecutor().Exec(db.SQLBuilder().CreateTable(*table)) +db.SQLExecutor().Exec(db.SQLBuilder().CreateTable(table)) ``` ### CRUD with the SQLBuilder and SQLExecutor @@ -209,41 +204,35 @@ type M0001CreateUserTable struct{} func (m M0001CreateUserTable) Up(migrator *migrations.Migrator) error { db := migrator.Database() - table, err := schema.NewTable("user") - if err != nil { - return err - } + table := schema.NewTable("user") if db.DBSchema().HasTable(table.Name()) { // if the table already exists, nothing to do here return nil } - columns := []schema.NewColumnParams{ - {Name: "id", Type: schema.Integer, PrimaryKey: true, AutoIncrement: true}, - {Name: "name", Type: schema.Varchar, Size: 50, NotNull: false}, - {Name: "email", Type: schema.Varchar, Size: 120, NotNull: false}, - {Name: "is_active", Type: schema.Boolean, NotNull: false}, - } - - for _, col := range columns { - if err := table.AddColumn(col); err != nil { - return err - } + err := table.AddColumns( + schema.NewColumnParams{Name: "id", Type: schema.Integer, PrimaryKey: true, AutoIncrement: true}, + schema.NewColumnParams{Name: "name", Type: schema.Varchar, Size: 50, NotNull: false}, + schema.NewColumnParams{Name: "email", Type: schema.Varchar, Size: 120, NotNull: false}, + schema.NewColumnParams{Name: "is_active", Type: schema.Boolean, NotNull: false}, + ) + if err != nil { + return err } - _, err := db.SQLExecutor().Exec(db.SQLBuilder().CreateTable(*table)) + _, err = db.SQLExecutor().Exec(db.SQLBuilder().CreateTable(table)) return err } func (m M0001CreateUserTable) Down(migrator *migrations.Migrator) error { db := migrator.Database() - tableName, _ := schema.NewTableName("user") - if !db.DBSchema().HasTable(*tableName) { + tableName := "user" + if !db.DBSchema().HasTable(tableName) { // if the table already doesn't exist, nothing to do here return nil } - _, err := db.SQLExecutor().Exec(db.SQLBuilder().DropTable(*tableName)) + _, err := db.SQLExecutor().Exec(db.SQLBuilder().DropTable(tableName)) return err } ``` @@ -256,42 +245,28 @@ type M0002AddUpdatedAtColumn struct{} func (m M0002AddUpdatedAtColumn) Up(migrator *migrations.Migrator) error { db := migrator.Database() - tableName, err := schema.NewTableName("user") - if err != nil { - return err - } - - col, err := schema.NewColumn(schema.NewColumnParams{Name: "updated_at", Type: schema.DateTime}) - if err != nil { - return err - } + tableName := "user" + col := schema.NewColumn(schema.NewColumnParams{Name: "updated_at", Type: schema.DateTime}) - if db.DBSchema().HasColumn(*tableName, col.Name()) { + if db.DBSchema().HasColumn(tableName, col.Name()) { // if the column already exists, nothing to do here return nil } - _, err := db.SQLExecutor().Exec(db.SQLBuilder().AlterTableAddColumn(*tableName, *col)) + _, err := db.SQLExecutor().Exec(db.SQLBuilder().AlterTableAddColumn(tableName, col)) return err } func (m M0002AddUpdatedAtColumn) Down(migrator *migrations.Migrator) error { db := migrator.Database() - tableName, err := schema.NewTableName("user") - if err != nil { - return err - } - - colName, err := schema.NewColumnName("updated_at") - if err != nil { - return err - } + tableName := "user" + colName := "updated_at" - if !db.DBSchema().HasColumn(*tableName, *colName) { + if !db.DBSchema().HasColumn(tableName, colName) { // if the column already doesn't exist, nothing to do here return nil } - _, err := db.SQLExecutor().Exec(db.SQLBuilder().AlterTableDropColumn(*tableName, *colName)) + _, err := db.SQLExecutor().Exec(db.SQLBuilder().AlterTableDropColumn(tableName, colName)) return err } ``` diff --git a/builder.go b/builder.go index 0df2939..95ff114 100644 --- a/builder.go +++ b/builder.go @@ -10,9 +10,9 @@ import ( // DDSQLBuilder creates DDL (Data Definition Language) SQL commands for defining schema in DBMS. type DDLSQLBuilder interface { CreateTable(pTable schema.Table) string - DropTable(pTableName schema.TableName) string - AlterTableAddColumn(pTableName schema.TableName, pColumn schema.Column) string - AlterTableDropColumn(pTableName schema.TableName, pColumnName schema.ColumnName) string + DropTable(pTableName string) string + AlterTableAddColumn(pTableName string, pColumn schema.Column) string + AlterTableDropColumn(pTableName, pColumnName string) string ColumnTypeAsString(pColumnType schema.ColumnType) string } @@ -24,9 +24,9 @@ type ColumnsValues map[string]any // lColumnsValues := ColumnsValues{"id": 5, "sku": "ZTX-9000", "is_simple": true} // lNamedArgs := lColumnsValues.ToNamedArgs() // //lNamedArgs == []sql.NamedArg{{Name: "id", Value: 5},{Name: "is_simple", Value: true},{Name: "sku", Value: "ZTX-9000"}} -func (cv ColumnsValues) ToNamedArgs() []sql.NamedArg { +func (cv *ColumnsValues) ToNamedArgs() []sql.NamedArg { lNamedArgs := []sql.NamedArg{} - for c, v := range cv { + for c, v := range *cv { lNamedArgs = append(lNamedArgs, sql.Named(c, v)) } slices.SortFunc(lNamedArgs, func(a, b sql.NamedArg) int { @@ -39,9 +39,9 @@ func (cv ColumnsValues) ToNamedArgs() []sql.NamedArg { } // ToColumns returns the column names from ColumnsValues as a string array ordered by name, e.g.: -func (cv ColumnsValues) ToColumns() []string { +func (cv *ColumnsValues) ToColumns() []string { lColumns := []string{} - for c := range cv { + for c := range *cv { lColumns = append(lColumns, c) } slices.Sort(lColumns) @@ -91,4 +91,5 @@ type DMLSQLBuilder interface { type SQLBuilder interface { DDLSQLBuilder DMLSQLBuilder + QuoteIdentifier(pIdentifier string) string } diff --git a/database.go b/database.go index 059caff..29b3bff 100644 --- a/database.go +++ b/database.go @@ -28,10 +28,9 @@ type DatabaseDriver interface { type Database struct { driver DatabaseDriver db *sql.DB - executor SQLExecutor connectionString string sqlBuilder SQLBuilder - schema schema.DBSchema + dbSchema *schema.DBSchema } func OpenDatabase(pDriver DatabaseDriver, pParams ConnectionParams) (*Database, error) { @@ -43,17 +42,17 @@ func OpenDatabase(pDriver DatabaseDriver, pParams ConnectionParams) (*Database, if lError != nil { return nil, fmt.Errorf("sql.Open failed: %w", lError) } - lSchema, lError := pDriver.DBSchema(lDB) + lDBSchema, lError := pDriver.DBSchema(lDB) if lError != nil { return nil, fmt.Errorf("failed to get DB schema: %w", lError) } + return &Database{ driver: pDriver, db: lDB, - executor: lDB, connectionString: lConnectionString, sqlBuilder: pDriver.SQLBuilder(), - schema: *lSchema, + dbSchema: lDBSchema, }, nil } @@ -66,7 +65,7 @@ func (d *Database) DB() *sql.DB { } func (d *Database) SQLExecutor() SQLExecutor { - return d.executor + return d.db } func (d *Database) DriverName() string { @@ -82,6 +81,6 @@ func (d *Database) SQLBuilder() SQLBuilder { return d.sqlBuilder } -func (d *Database) DBSchema() schema.DBSchema { - return d.schema +func (d *Database) DBSchema() *schema.DBSchema { + return d.dbSchema } diff --git a/database_test.go b/database_test.go index 8c98070..852b163 100644 --- a/database_test.go +++ b/database_test.go @@ -124,7 +124,7 @@ func TestDriverConnectionString(t *testing.T) { func TestDriverSQLBuilder(t *testing.T) { lDriver := testutils.NewFakeDriver(sqlite.Driver()) lSQLBuilder := lDriver.SQLBuilder() - testutils.AssertEqualWithLabel(t, "sqliteBuilder", reflect.TypeOf(lSQLBuilder).Name(), "FakeDriver.SQLBuilder") + testutils.AssertEqualWithLabel(t, "sqliteBuilder", reflect.TypeOf(lSQLBuilder).Elem().Name(), "FakeDriver.SQLBuilder") } func TestDriverDBSchema(t *testing.T) { diff --git a/dialects/postgresql/builder.go b/dialects/postgresql/builder.go index f8c1826..ba0fd99 100644 --- a/dialects/postgresql/builder.go +++ b/dialects/postgresql/builder.go @@ -16,28 +16,28 @@ type postgresqlBuilder struct { } func newPostgreSQLBuilder() ormshift.SQLBuilder { - lBuilder := postgresqlBuilder{} - lBuilder.generic = internal.NewGenericSQLBuilder(lBuilder.columnDefinition, lBuilder.InteroperateSQLCommandWithNamedArgs) - return lBuilder + sb := postgresqlBuilder{} + sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, nil, sb.InteroperateSQLCommandWithNamedArgs) + return &sb } -func (sb postgresqlBuilder) CreateTable(pTable schema.Table) string { +func (sb *postgresqlBuilder) CreateTable(pTable schema.Table) string { return sb.generic.CreateTable(pTable) } -func (sb postgresqlBuilder) DropTable(pTableName schema.TableName) string { +func (sb *postgresqlBuilder) DropTable(pTableName string) string { return sb.generic.DropTable(pTableName) } -func (sb postgresqlBuilder) AlterTableAddColumn(pTableName schema.TableName, pColumn schema.Column) string { +func (sb *postgresqlBuilder) AlterTableAddColumn(pTableName string, pColumn schema.Column) string { return sb.generic.AlterTableAddColumn(pTableName, pColumn) } -func (sb postgresqlBuilder) AlterTableDropColumn(pTableName schema.TableName, pColumnName schema.ColumnName) string { +func (sb *postgresqlBuilder) AlterTableDropColumn(pTableName, pColumnName string) string { return sb.generic.AlterTableDropColumn(pTableName, pColumnName) } -func (sb postgresqlBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) string { +func (sb *postgresqlBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) string { switch pColumnType { case schema.Varchar: return "VARCHAR" @@ -58,8 +58,8 @@ func (sb postgresqlBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) st } } -func (sb postgresqlBuilder) columnDefinition(pColumn schema.Column) string { - lColumnDef := pColumn.Name().String() +func (sb *postgresqlBuilder) columnDefinition(pColumn schema.Column) string { + lColumnDef := sb.QuoteIdentifier(pColumn.Name()) if pColumn.AutoIncrement() { lColumnDef += " BIGSERIAL" } else { @@ -75,43 +75,47 @@ func (sb postgresqlBuilder) columnDefinition(pColumn schema.Column) string { return lColumnDef } -func (sb postgresqlBuilder) Insert(pTableName string, pColumns []string) string { +func (sb *postgresqlBuilder) Insert(pTableName string, pColumns []string) string { return sb.generic.Insert(pTableName, pColumns) } -func (sb postgresqlBuilder) InsertWithValues(pTableName string, pColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *postgresqlBuilder) InsertWithValues(pTableName string, pColumnsValues ormshift.ColumnsValues) (string, []any) { return sb.generic.InsertWithValues(pTableName, pColumnsValues) } -func (sb postgresqlBuilder) Update(pTableName string, pColumns, pColumnsWhere []string) string { +func (sb *postgresqlBuilder) Update(pTableName string, pColumns, pColumnsWhere []string) string { return sb.generic.Update(pTableName, pColumns, pColumnsWhere) } -func (sb postgresqlBuilder) UpdateWithValues(pTableName string, pColumns, pColumnsWhere []string, pValues ormshift.ColumnsValues) (string, []any) { +func (sb *postgresqlBuilder) UpdateWithValues(pTableName string, pColumns, pColumnsWhere []string, pValues ormshift.ColumnsValues) (string, []any) { return sb.generic.UpdateWithValues(pTableName, pColumns, pColumnsWhere, pValues) } -func (sb postgresqlBuilder) Delete(pTableName string, pColumnsWhere []string) string { +func (sb *postgresqlBuilder) Delete(pTableName string, pColumnsWhere []string) string { return sb.generic.Delete(pTableName, pColumnsWhere) } -func (sb postgresqlBuilder) DeleteWithValues(pTableName string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *postgresqlBuilder) DeleteWithValues(pTableName string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { return sb.generic.DeleteWithValues(pTableName, pWhereColumnsValues) } -func (sb postgresqlBuilder) Select(pTableName string, pColumns, pColumnsWhere []string) string { +func (sb *postgresqlBuilder) Select(pTableName string, pColumns, pColumnsWhere []string) string { return sb.generic.Select(pTableName, pColumns, pColumnsWhere) } -func (sb postgresqlBuilder) SelectWithValues(pTableName string, pColumns []string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *postgresqlBuilder) SelectWithValues(pTableName string, pColumns []string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { return sb.generic.SelectWithValues(pTableName, pColumns, pWhereColumnsValues) } -func (sb postgresqlBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsPerPage, pPageNumber uint) string { +func (sb *postgresqlBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsPerPage, pPageNumber uint) string { return sb.generic.SelectWithPagination(pSQLSelectCommand, pRowsPerPage, pPageNumber) } -func (sb postgresqlBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { +func (sb *postgresqlBuilder) QuoteIdentifier(pIdentifier string) string { + return sb.generic.QuoteIdentifier(pIdentifier) +} + +func (sb *postgresqlBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { lSQLCommand := pSQLCommand lArgs := []any{} lIndexes := map[string]int{} diff --git a/dialects/postgresql/builder_test.go b/dialects/postgresql/builder_test.go index 698acff..62a838f 100644 --- a/dialects/postgresql/builder_test.go +++ b/dialects/postgresql/builder_test.go @@ -42,15 +42,15 @@ func TestCreateTable(t *testing.T) { lSQLBuilder := postgresql.Driver().SQLBuilder() lUserTable := testutils.FakeUserTable(t) - lExpectedSQL := "CREATE TABLE user (id BIGSERIAL NOT NULL,email VARCHAR(80) NOT NULL,name VARCHAR(50) NOT NULL," + - "password_hash VARCHAR(256),active SMALLINT,created_at TIMESTAMP(6),user_master BIGINT,master_user_id BIGINT," + - "licence_price NUMERIC(17,2),relevance DOUBLE PRECISION,photo BYTEA,any VARCHAR,PRIMARY KEY (id,email));" - lReturnedSQL := lSQLBuilder.CreateTable(*lUserTable) + lExpectedSQL := "CREATE TABLE \"user\" (\"id\" BIGSERIAL NOT NULL,\"email\" VARCHAR(80) NOT NULL,\"name\" VARCHAR(50) NOT NULL," + + "\"password_hash\" VARCHAR(256),\"active\" SMALLINT,\"created_at\" TIMESTAMP(6),\"user_master\" BIGINT,\"master_user_id\" BIGINT," + + "\"licence_price\" NUMERIC(17,2),\"relevance\" DOUBLE PRECISION,\"photo\" BYTEA,\"any\" VARCHAR,PRIMARY KEY (\"id\",\"email\"));" + lReturnedSQL := lSQLBuilder.CreateTable(lUserTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") lProductAttributeTable := testutils.FakeProductAttributeTable(t) - lExpectedSQL = "CREATE TABLE product_attribute (product_id BIGINT NOT NULL,attribute_id BIGINT NOT NULL,value VARCHAR(75),position BIGINT,PRIMARY KEY (product_id,attribute_id));" - lReturnedSQL = lSQLBuilder.CreateTable(*lProductAttributeTable) + lExpectedSQL = "CREATE TABLE \"product_attribute\" (\"product_id\" BIGINT NOT NULL,\"attribute_id\" BIGINT NOT NULL,\"value\" VARCHAR(75),\"position\" BIGINT,PRIMARY KEY (\"product_id\",\"attribute_id\"));" + lReturnedSQL = lSQLBuilder.CreateTable(lProductAttributeTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") } @@ -58,8 +58,8 @@ func TestDropTable(t *testing.T) { lSQLBuilder := postgresql.Driver().SQLBuilder() lUserTableName := testutils.FakeUserTableName(t) - lExpectedSQL := "DROP TABLE user;" - lReturnedSQL := lSQLBuilder.DropTable(*lUserTableName) + lExpectedSQL := "DROP TABLE \"user\";" + lReturnedSQL := lSQLBuilder.DropTable(lUserTableName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.DropTable") } @@ -68,8 +68,8 @@ func TestAlterTableAddColumn(t *testing.T) { lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumn := testutils.FakeUpdatedAtColumn(t) - lExpectedSQL := "ALTER TABLE user ADD COLUMN updated_at TIMESTAMP(6);" - lReturnedSQL := lSQLBuilder.AlterTableAddColumn(*lUserTableName, *lUpdatedAtColumn) + lExpectedSQL := "ALTER TABLE \"user\" ADD COLUMN \"updated_at\" TIMESTAMP(6);" + lReturnedSQL := lSQLBuilder.AlterTableAddColumn(lUserTableName, lUpdatedAtColumn) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableAddColumn") } @@ -78,8 +78,8 @@ func TestAlterTableDropColumn(t *testing.T) { lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumnName := testutils.FakeUpdatedAtColumnName(t) - lExpectedSQL := "ALTER TABLE user DROP COLUMN updated_at;" - lReturnedSQL := lSQLBuilder.AlterTableDropColumn(*lUserTableName, *lUpdatedAtColumnName) + lExpectedSQL := "ALTER TABLE \"user\" DROP COLUMN \"updated_at\";" + lReturnedSQL := lSQLBuilder.AlterTableDropColumn(lUserTableName, lUpdatedAtColumnName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableDropColumn") } @@ -87,7 +87,7 @@ func TestInsert(t *testing.T) { lSQLBuilder := postgresql.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Insert("product", []string{"id", "sku", "name", "description"}) - lExpectedSQL := "insert into product (id,sku,name,description) values (@id,@sku,@name,@description)" + lExpectedSQL := "insert into \"product\" (\"id\",\"sku\",\"name\",\"description\") values (@id,@sku,@name,@description)" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Insert") } @@ -95,7 +95,7 @@ func TestInsertWithValues(t *testing.T) { lSQLBuilder := postgresql.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.InsertWithValues("product", ormshift.ColumnsValues{"id": 1, "sku": "1.005.12.9", "name": "Trufa Sabor Amarula 30g Cacaushow"}) - lExpectedSQL := "insert into product (id,name,sku) values ($1,$2,$3)" + lExpectedSQL := "insert into \"product\" (\"id\",\"name\",\"sku\") values ($1,$2,$3)" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.InsertWithValues.SQL") testutils.AssertEqualWithLabel(t, 3, len(lReturnedValues), "SQLBuilder.InsertWithValues.Values") testutils.AssertEqualWithLabel(t, 1, lReturnedValues[0], "SQLBuilder.InsertWithValues.Values[0]") @@ -107,7 +107,7 @@ func TestUpdate(t *testing.T) { lSQLBuilder := postgresql.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Update("product", []string{"sku", "name", "description"}, []string{"id"}) - lExpectedSQL := "update product set sku = @sku,name = @name,description = @description where id = @id" + lExpectedSQL := "update \"product\" set \"sku\" = @sku,\"name\" = @name,\"description\" = @description where \"id\" = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Update") } @@ -115,7 +115,7 @@ func TestUpdateWithValues(t *testing.T) { lSQLBuilder := postgresql.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.UpdateWithValues("product", []string{"sku", "name"}, []string{"id"}, ormshift.ColumnsValues{"id": 1, "sku": "1.005.12.5", "name": "Trufa Sabor Amarula 18g Cacaushow"}) - lExpectedSQL := "update product set sku = $3,name = $2 where id = $1" + lExpectedSQL := "update \"product\" set \"sku\" = $3,\"name\" = $2 where \"id\" = $1" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.UpdateWithValues.SQL") testutils.AssertEqualWithLabel(t, 3, len(lReturnedValues), "SQLBuilder.UpdateWithValues.Values") testutils.AssertEqualWithLabel(t, 1, lReturnedValues[0], "SQLBuilder.UpdateWithValues.Values[0]") @@ -127,7 +127,7 @@ func TestDelete(t *testing.T) { lSQLBuilder := postgresql.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Delete("product", []string{"id"}) - lExpectedSQL := "delete from product where id = @id" + lExpectedSQL := "delete from \"product\" where \"id\" = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Delete") } @@ -135,7 +135,7 @@ func TestDeleteWithValues(t *testing.T) { lSQLBuilder := postgresql.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.DeleteWithValues("product", ormshift.ColumnsValues{"id": 1}) - lExpectedSQL := "delete from product where id = $1" + lExpectedSQL := "delete from \"product\" where \"id\" = $1" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.DeleteWithValues.SQL") testutils.AssertEqualWithLabel(t, 1, len(lReturnedValues), "SQLBuilder.DeleteWithValues.Values") testutils.AssertEqualWithLabel(t, 1, lReturnedValues[0], "SQLBuilder.DeleteWithValues.Values[0]") @@ -145,7 +145,7 @@ func TestSelect(t *testing.T) { lSQLBuilder := postgresql.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Select("product", []string{"id", "name", "description"}, []string{"sku", "active"}) - lExpectedSQL := "select id,name,description from product where sku = @sku and active = @active" + lExpectedSQL := "select \"id\",\"name\",\"description\" from \"product\" where \"sku\" = @sku and \"active\" = @active" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Select") } @@ -153,7 +153,7 @@ func TestSelectWithValues(t *testing.T) { lSQLBuilder := postgresql.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.SelectWithValues("product", []string{"id", "sku", "name", "description"}, ormshift.ColumnsValues{"category_id": 1, "active": true}) - lExpectedSQL := "select id,sku,name,description from product where active = $1 and category_id = $2" + lExpectedSQL := "select \"id\",\"sku\",\"name\",\"description\" from \"product\" where \"active\" = $1 and \"category_id\" = $2" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.SelectWithValues.SQL") testutils.AssertEqualWithLabel(t, 2, len(lReturnedValues), "SQLBuilder.SelectWithValues.Values") testutils.AssertEqualWithLabel(t, 1, lReturnedValues[0], "SQLBuilder.SelectWithValues.Values[0]") diff --git a/dialects/postgresql/driver.go b/dialects/postgresql/driver.go index 4962e14..1123ac7 100644 --- a/dialects/postgresql/driver.go +++ b/dialects/postgresql/driver.go @@ -11,17 +11,21 @@ import ( "github.com/ordershift/ormshift/schema" ) -type postgresqlDriver struct{} +type postgresqlDriver struct { + sqlBuilder ormshift.SQLBuilder +} func Driver() ormshift.DatabaseDriver { - return postgresqlDriver{} + return &postgresqlDriver{ + sqlBuilder: newPostgreSQLBuilder(), + } } -func (d postgresqlDriver) Name() string { +func (d *postgresqlDriver) Name() string { return "postgres" } -func (d postgresqlDriver) ConnectionString(pParams ormshift.ConnectionParams) string { +func (d *postgresqlDriver) ConnectionString(pParams ormshift.ConnectionParams) string { lHost := pParams.Host if lHost == "" { lHost = "localhost" @@ -33,10 +37,10 @@ func (d postgresqlDriver) ConnectionString(pParams ormshift.ConnectionParams) st return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", lHost, lPort, pParams.User, pParams.Password, pParams.Database) } -func (d postgresqlDriver) SQLBuilder() ormshift.SQLBuilder { - return newPostgreSQLBuilder() +func (d *postgresqlDriver) SQLBuilder() ormshift.SQLBuilder { + return d.sqlBuilder } -func (d postgresqlDriver) DBSchema(pDB *sql.DB) (*schema.DBSchema, error) { - return schema.NewDBSchema(pDB, tableNamesQuery) +func (d *postgresqlDriver) DBSchema(pDB *sql.DB) (*schema.DBSchema, error) { + return schema.NewDBSchema(pDB, tableNamesQuery, columnTypesQueryFunc(d.sqlBuilder)) } diff --git a/dialects/postgresql/schema.go b/dialects/postgresql/schema.go index 00e5d79..722d04c 100644 --- a/dialects/postgresql/schema.go +++ b/dialects/postgresql/schema.go @@ -1,5 +1,11 @@ package postgresql +import ( + "fmt" + + "github.com/ordershift/ormshift" +) + const tableNamesQuery = ` SELECT table_name @@ -11,3 +17,9 @@ const tableNamesQuery = ` ORDER BY table_name ` + +func columnTypesQueryFunc(pSQLBuilder ormshift.SQLBuilder) func(string) string { + return func(pTableName string) string { + return fmt.Sprintf("SELECT * FROM %s WHERE 1=0", pSQLBuilder.QuoteIdentifier(pTableName)) + } +} diff --git a/dialects/postgresql/schema_test.go b/dialects/postgresql/schema_test.go new file mode 100644 index 0000000..4c44cd1 --- /dev/null +++ b/dialects/postgresql/schema_test.go @@ -0,0 +1,33 @@ +package postgresql_test + +import ( + "testing" + + "github.com/ordershift/ormshift" + "github.com/ordershift/ormshift/dialects/postgresql" + "github.com/ordershift/ormshift/internal/testutils" +) + +func TestHasTable(t *testing.T) { + lDB, lError := ormshift.OpenDatabase(postgresql.Driver(), ormshift.ConnectionParams{InMemory: true}) + if !testutils.AssertNotNilResultAndNilError(t, lDB, lError, "ormshift.OpenDatabase") { + return + } + defer func() { _ = lDB.Close() }() + + if lDB.DBSchema().HasTable("user") { + t.Errorf("Expected HasTable('user') to be false") + } +} + +func TestHasColumn(t *testing.T) { + lDB, lError := ormshift.OpenDatabase(postgresql.Driver(), ormshift.ConnectionParams{InMemory: true}) + if !testutils.AssertNotNilResultAndNilError(t, lDB, lError, "ormshift.OpenDatabase") { + return + } + defer func() { _ = lDB.Close() }() + + if lDB.DBSchema().HasColumn("user", "id") { + t.Errorf("Expected HasColumn('user', 'id') to be false") + } +} diff --git a/dialects/sqlite/builder.go b/dialects/sqlite/builder.go index 07cc112..d158670 100644 --- a/dialects/sqlite/builder.go +++ b/dialects/sqlite/builder.go @@ -14,12 +14,12 @@ type sqliteBuilder struct { } func newSQLiteBuilder() ormshift.SQLBuilder { - lBuilder := sqliteBuilder{} - lBuilder.generic = internal.NewGenericSQLBuilder(lBuilder.columnDefinition, nil) - return lBuilder + sb := sqliteBuilder{} + sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, nil, nil) + return &sb } -func (sb sqliteBuilder) CreateTable(pTable schema.Table) string { +func (sb *sqliteBuilder) CreateTable(pTable schema.Table) string { lColumns := "" lPKColumns := "" lHasAutoIncrementColumn := false @@ -33,7 +33,7 @@ func (sb sqliteBuilder) CreateTable(pTable schema.Table) string { if lPKColumns != "" { lPKColumns += "," } - lPKColumns += lColumn.Name().String() + lPKColumns += sb.QuoteIdentifier(lColumn.Name()) } if !lHasAutoIncrementColumn { @@ -45,24 +45,25 @@ func (sb sqliteBuilder) CreateTable(pTable schema.Table) string { if lColumns != "" { lColumns += "," } - lColumns += fmt.Sprintf("CONSTRAINT PK_%s PRIMARY KEY (%s)", pTable.Name().String(), lPKColumns) + lPKConstraintName := sb.QuoteIdentifier("PK_" + pTable.Name()) + lColumns += fmt.Sprintf("CONSTRAINT %s PRIMARY KEY (%s)", lPKConstraintName, lPKColumns) } - return fmt.Sprintf("CREATE TABLE %s (%s);", pTable.Name().String(), lColumns) + return fmt.Sprintf("CREATE TABLE %s (%s);", sb.QuoteIdentifier(pTable.Name()), lColumns) } -func (sb sqliteBuilder) DropTable(pTableName schema.TableName) string { +func (sb *sqliteBuilder) DropTable(pTableName string) string { return sb.generic.DropTable(pTableName) } -func (sb sqliteBuilder) AlterTableAddColumn(pTableName schema.TableName, pColumn schema.Column) string { +func (sb *sqliteBuilder) AlterTableAddColumn(pTableName string, pColumn schema.Column) string { return sb.generic.AlterTableAddColumn(pTableName, pColumn) } -func (sb sqliteBuilder) AlterTableDropColumn(pTableName schema.TableName, pColumnName schema.ColumnName) string { +func (sb *sqliteBuilder) AlterTableDropColumn(pTableName, pColumnName string) string { return sb.generic.AlterTableDropColumn(pTableName, pColumnName) } -func (sb sqliteBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) string { +func (sb *sqliteBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) string { switch pColumnType { case schema.Varchar: return "TEXT" @@ -83,8 +84,8 @@ func (sb sqliteBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) string } } -func (sb sqliteBuilder) columnDefinition(pColumn schema.Column) string { - lColumnDef := fmt.Sprintf("%s %s", pColumn.Name().String(), sb.ColumnTypeAsString(pColumn.Type())) +func (sb *sqliteBuilder) columnDefinition(pColumn schema.Column) string { + lColumnDef := fmt.Sprintf("%s %s", sb.QuoteIdentifier(pColumn.Name()), sb.ColumnTypeAsString(pColumn.Type())) if pColumn.NotNull() { lColumnDef += " NOT NULL" } @@ -94,42 +95,46 @@ func (sb sqliteBuilder) columnDefinition(pColumn schema.Column) string { return lColumnDef } -func (sb sqliteBuilder) Insert(pTableName string, pColumns []string) string { +func (sb *sqliteBuilder) Insert(pTableName string, pColumns []string) string { return sb.generic.Insert(pTableName, pColumns) } -func (sb sqliteBuilder) InsertWithValues(pTableName string, pColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *sqliteBuilder) InsertWithValues(pTableName string, pColumnsValues ormshift.ColumnsValues) (string, []any) { return sb.generic.InsertWithValues(pTableName, pColumnsValues) } -func (sb sqliteBuilder) Update(pTableName string, pColumns, pColumnsWhere []string) string { +func (sb *sqliteBuilder) Update(pTableName string, pColumns, pColumnsWhere []string) string { return sb.generic.Update(pTableName, pColumns, pColumnsWhere) } -func (sb sqliteBuilder) UpdateWithValues(pTableName string, pColumns, pColumnsWhere []string, pValues ormshift.ColumnsValues) (string, []any) { +func (sb *sqliteBuilder) UpdateWithValues(pTableName string, pColumns, pColumnsWhere []string, pValues ormshift.ColumnsValues) (string, []any) { return sb.generic.UpdateWithValues(pTableName, pColumns, pColumnsWhere, pValues) } -func (sb sqliteBuilder) Delete(pTableName string, pColumnsWhere []string) string { +func (sb *sqliteBuilder) Delete(pTableName string, pColumnsWhere []string) string { return sb.generic.Delete(pTableName, pColumnsWhere) } -func (sb sqliteBuilder) DeleteWithValues(pTableName string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *sqliteBuilder) DeleteWithValues(pTableName string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { return sb.generic.DeleteWithValues(pTableName, pWhereColumnsValues) } -func (sb sqliteBuilder) Select(pTableName string, pColumns, pColumnsWhere []string) string { +func (sb *sqliteBuilder) Select(pTableName string, pColumns, pColumnsWhere []string) string { return sb.generic.Select(pTableName, pColumns, pColumnsWhere) } -func (sb sqliteBuilder) SelectWithValues(pTableName string, pColumns []string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *sqliteBuilder) SelectWithValues(pTableName string, pColumns []string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { return sb.generic.SelectWithValues(pTableName, pColumns, pWhereColumnsValues) } -func (sb sqliteBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsPerPage, pPageNumber uint) string { +func (sb *sqliteBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsPerPage, pPageNumber uint) string { return sb.generic.SelectWithPagination(pSQLSelectCommand, pRowsPerPage, pPageNumber) } -func (sb sqliteBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { +func (sb *sqliteBuilder) QuoteIdentifier(pIdentifier string) string { + return sb.generic.QuoteIdentifier(pIdentifier) +} + +func (sb *sqliteBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { return sb.generic.InteroperateSQLCommandWithNamedArgs(pSQLCommand, pNamedArgs...) } diff --git a/dialects/sqlite/builder_test.go b/dialects/sqlite/builder_test.go index 95e8ab9..fb386c5 100644 --- a/dialects/sqlite/builder_test.go +++ b/dialects/sqlite/builder_test.go @@ -22,14 +22,14 @@ func TestCreateTable(t *testing.T) { lSQLBuilder := sqlite.Driver().SQLBuilder() lUserTable := testutils.FakeUserTable(t) - lExpectedSQL := "CREATE TABLE user (id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,email TEXT NOT NULL,name TEXT NOT NULL," + - "password_hash TEXT,active INTEGER,created_at DATETIME,user_master INTEGER,master_user_id INTEGER,licence_price REAL,relevance REAL,photo BLOB,any TEXT);" - lReturnedSQL := lSQLBuilder.CreateTable(*lUserTable) + lExpectedSQL := "CREATE TABLE \"user\" (\"id\" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,\"email\" TEXT NOT NULL,\"name\" TEXT NOT NULL," + + "\"password_hash\" TEXT,\"active\" INTEGER,\"created_at\" DATETIME,\"user_master\" INTEGER,\"master_user_id\" INTEGER,\"licence_price\" REAL,\"relevance\" REAL,\"photo\" BLOB,\"any\" TEXT);" + lReturnedSQL := lSQLBuilder.CreateTable(lUserTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") lProductAttributeTable := testutils.FakeProductAttributeTable(t) - lExpectedSQL = "CREATE TABLE product_attribute (product_id INTEGER NOT NULL,attribute_id INTEGER NOT NULL,value TEXT,position INTEGER,CONSTRAINT PK_product_attribute PRIMARY KEY (product_id,attribute_id));" - lReturnedSQL = lSQLBuilder.CreateTable(*lProductAttributeTable) + lExpectedSQL = "CREATE TABLE \"product_attribute\" (\"product_id\" INTEGER NOT NULL,\"attribute_id\" INTEGER NOT NULL,\"value\" TEXT,\"position\" INTEGER,CONSTRAINT \"PK_product_attribute\" PRIMARY KEY (\"product_id\",\"attribute_id\"));" + lReturnedSQL = lSQLBuilder.CreateTable(lProductAttributeTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") } @@ -37,8 +37,8 @@ func TestDropTable(t *testing.T) { lSQLBuilder := sqlite.Driver().SQLBuilder() lUserTableName := testutils.FakeUserTableName(t) - lExpectedSQL := "DROP TABLE user;" - lReturnedSQL := lSQLBuilder.DropTable(*lUserTableName) + lExpectedSQL := "DROP TABLE \"user\";" + lReturnedSQL := lSQLBuilder.DropTable(lUserTableName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.DropTable") } @@ -47,8 +47,8 @@ func TestAlterTableAddColumn(t *testing.T) { lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumn := testutils.FakeUpdatedAtColumn(t) - lExpectedSQL := "ALTER TABLE user ADD COLUMN updated_at DATETIME;" - lReturnedSQL := lSQLBuilder.AlterTableAddColumn(*lUserTableName, *lUpdatedAtColumn) + lExpectedSQL := "ALTER TABLE \"user\" ADD COLUMN \"updated_at\" DATETIME;" + lReturnedSQL := lSQLBuilder.AlterTableAddColumn(lUserTableName, lUpdatedAtColumn) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableAddColumn") } @@ -57,8 +57,8 @@ func TestAlterTableDropColumn(t *testing.T) { lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumnName := testutils.FakeUpdatedAtColumnName(t) - lExpectedSQL := "ALTER TABLE user DROP COLUMN updated_at;" - lReturnedSQL := lSQLBuilder.AlterTableDropColumn(*lUserTableName, *lUpdatedAtColumnName) + lExpectedSQL := "ALTER TABLE \"user\" DROP COLUMN \"updated_at\";" + lReturnedSQL := lSQLBuilder.AlterTableDropColumn(lUserTableName, lUpdatedAtColumnName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableDropColumn") } @@ -66,7 +66,7 @@ func TestInsert(t *testing.T) { lSQLBuilder := sqlite.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Insert("product", []string{"id", "sku", "name", "description"}) - lExpectedSQL := "insert into product (id,sku,name,description) values (@id,@sku,@name,@description)" + lExpectedSQL := "insert into \"product\" (\"id\",\"sku\",\"name\",\"description\") values (@id,@sku,@name,@description)" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Insert") } @@ -74,7 +74,7 @@ func TestInsertWithValues(t *testing.T) { lSQLBuilder := sqlite.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.InsertWithValues("product", ormshift.ColumnsValues{"id": 1, "sku": "1.005.12.9", "name": "Trufa Sabor Amarula 30g Cacaushow"}) - lExpectedSQL := "insert into product (id,name,sku) values (@id,@name,@sku)" + lExpectedSQL := "insert into \"product\" (\"id\",\"name\",\"sku\") values (@id,@name,@sku)" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.InsertWithValues.SQL") testutils.AssertEqualWithLabel(t, 3, len(lReturnedValues), "SQLBuilder.InsertWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "id", Value: 1}, "SQLBuilder.InsertWithValues.Values[0]") @@ -86,7 +86,7 @@ func TestUpdate(t *testing.T) { lSQLBuilder := sqlite.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Update("product", []string{"sku", "name", "description"}, []string{"id"}) - lExpectedSQL := "update product set sku = @sku,name = @name,description = @description where id = @id" + lExpectedSQL := "update \"product\" set \"sku\" = @sku,\"name\" = @name,\"description\" = @description where \"id\" = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Update") } @@ -94,7 +94,7 @@ func TestUpdateWithValues(t *testing.T) { lSQLBuilder := sqlite.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.UpdateWithValues("product", []string{"sku", "name"}, []string{"id"}, ormshift.ColumnsValues{"id": 1, "sku": "1.005.12.5", "name": "Trufa Sabor Amarula 18g Cacaushow"}) - lExpectedSQL := "update product set sku = @sku,name = @name where id = @id" + lExpectedSQL := "update \"product\" set \"sku\" = @sku,\"name\" = @name where \"id\" = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.UpdateWithValues.SQL") testutils.AssertEqualWithLabel(t, 3, len(lReturnedValues), "SQLBuilder.UpdateWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "id", Value: 1}, "SQLBuilder.UpdateWithValues.Values[0]") @@ -106,7 +106,7 @@ func TestDelete(t *testing.T) { lSQLBuilder := sqlite.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Delete("product", []string{"id"}) - lExpectedSQL := "delete from product where id = @id" + lExpectedSQL := "delete from \"product\" where \"id\" = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Delete") } @@ -114,7 +114,7 @@ func TestDeleteWithValues(t *testing.T) { lSQLBuilder := sqlite.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.DeleteWithValues("product", ormshift.ColumnsValues{"id": 1}) - lExpectedSQL := "delete from product where id = @id" + lExpectedSQL := "delete from \"product\" where \"id\" = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.DeleteWithValues.SQL") testutils.AssertEqualWithLabel(t, 1, len(lReturnedValues), "SQLBuilder.DeleteWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "id", Value: 1}, "SQLBuilder.DeleteWithValues.Values[0]") @@ -124,7 +124,7 @@ func TestSelect(t *testing.T) { lSQLBuilder := sqlite.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Select("product", []string{"id", "name", "description"}, []string{"sku", "active"}) - lExpectedSQL := "select id,name,description from product where sku = @sku and active = @active" + lExpectedSQL := "select \"id\",\"name\",\"description\" from \"product\" where \"sku\" = @sku and \"active\" = @active" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Select") } @@ -132,7 +132,7 @@ func TestSelectWithValues(t *testing.T) { lSQLBuilder := sqlite.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.SelectWithValues("product", []string{"id", "sku", "name", "description"}, ormshift.ColumnsValues{"category_id": 1, "active": true}) - lExpectedSQL := "select id,sku,name,description from product where active = @active and category_id = @category_id" + lExpectedSQL := "select \"id\",\"sku\",\"name\",\"description\" from \"product\" where \"active\" = @active and \"category_id\" = @category_id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.SelectWithValues.SQL") testutils.AssertEqualWithLabel(t, 2, len(lReturnedValues), "SQLBuilder.SelectWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "active", Value: true}, "SQLBuilder.SelectWithValues.Values[0]") diff --git a/dialects/sqlite/driver.go b/dialects/sqlite/driver.go index 9c5c808..488aaa4 100644 --- a/dialects/sqlite/driver.go +++ b/dialects/sqlite/driver.go @@ -11,17 +11,21 @@ import ( "github.com/ordershift/ormshift/schema" ) -type sqliteDriver struct{} +type sqliteDriver struct { + sqlBuilder ormshift.SQLBuilder +} func Driver() ormshift.DatabaseDriver { - return sqliteDriver{} + return &sqliteDriver{ + sqlBuilder: newSQLiteBuilder(), + } } -func (d sqliteDriver) Name() string { +func (d *sqliteDriver) Name() string { return "sqlite" } -func (d sqliteDriver) ConnectionString(pParams ormshift.ConnectionParams) string { +func (d *sqliteDriver) ConnectionString(pParams ormshift.ConnectionParams) string { if pParams.InMemory { return ":memory:" } @@ -35,10 +39,10 @@ func (d sqliteDriver) ConnectionString(pParams ormshift.ConnectionParams) string return fmt.Sprintf("file:%s.db?%s_locking=NORMAL", pParams.Database, lConnetionWithAuth) } -func (d sqliteDriver) SQLBuilder() ormshift.SQLBuilder { - return newSQLiteBuilder() +func (d *sqliteDriver) SQLBuilder() ormshift.SQLBuilder { + return d.sqlBuilder } -func (d sqliteDriver) DBSchema(pDB *sql.DB) (*schema.DBSchema, error) { - return schema.NewDBSchema(pDB, tableNamesQuery) +func (d *sqliteDriver) DBSchema(pDB *sql.DB) (*schema.DBSchema, error) { + return schema.NewDBSchema(pDB, tableNamesQuery, columnTypesQueryFunc(d.sqlBuilder)) } diff --git a/dialects/sqlite/schema.go b/dialects/sqlite/schema.go index fa284eb..9519537 100644 --- a/dialects/sqlite/schema.go +++ b/dialects/sqlite/schema.go @@ -1,5 +1,11 @@ package sqlite +import ( + "fmt" + + "github.com/ordershift/ormshift" +) + const tableNamesQuery = ` SELECT name @@ -10,3 +16,9 @@ const tableNamesQuery = ` ORDER BY name ` + +func columnTypesQueryFunc(pSQLBuilder ormshift.SQLBuilder) func(string) string { + return func(pTableName string) string { + return fmt.Sprintf("SELECT * FROM %s WHERE 1=0", pSQLBuilder.QuoteIdentifier(pTableName)) + } +} diff --git a/dialects/sqlite/schema_test.go b/dialects/sqlite/schema_test.go new file mode 100644 index 0000000..2718ff7 --- /dev/null +++ b/dialects/sqlite/schema_test.go @@ -0,0 +1,33 @@ +package sqlite_test + +import ( + "testing" + + "github.com/ordershift/ormshift" + "github.com/ordershift/ormshift/dialects/sqlite" + "github.com/ordershift/ormshift/internal/testutils" +) + +func TestHasTable(t *testing.T) { + lDB, lError := ormshift.OpenDatabase(sqlite.Driver(), ormshift.ConnectionParams{InMemory: true}) + if !testutils.AssertNotNilResultAndNilError(t, lDB, lError, "ormshift.OpenDatabase") { + return + } + defer func() { _ = lDB.Close() }() + + if lDB.DBSchema().HasTable("user") { + t.Errorf("Expected HasTable('user') to be false") + } +} + +func TestHasColumn(t *testing.T) { + lDB, lError := ormshift.OpenDatabase(sqlite.Driver(), ormshift.ConnectionParams{InMemory: true}) + if !testutils.AssertNotNilResultAndNilError(t, lDB, lError, "ormshift.OpenDatabase") { + return + } + defer func() { _ = lDB.Close() }() + + if lDB.DBSchema().HasColumn("user", "id") { + t.Errorf("Expected HasColumn('user', 'id') to be false") + } +} diff --git a/dialects/sqlserver/builder.go b/dialects/sqlserver/builder.go index 5227a83..12a9871 100644 --- a/dialects/sqlserver/builder.go +++ b/dialects/sqlserver/builder.go @@ -3,6 +3,7 @@ package sqlserver import ( "database/sql" "fmt" + "strings" "github.com/ordershift/ormshift" "github.com/ordershift/ormshift/internal" @@ -14,12 +15,12 @@ type sqlserverBuilder struct { } func newSQLServerBuilder() ormshift.SQLBuilder { - lBuilder := sqlserverBuilder{} - lBuilder.generic = internal.NewGenericSQLBuilder(lBuilder.columnDefinition, nil) - return lBuilder + sb := sqlserverBuilder{} + sb.generic = internal.NewGenericSQLBuilder(sb.columnDefinition, sb.QuoteIdentifier, nil) + return &sb } -func (sb sqlserverBuilder) CreateTable(pTable schema.Table) string { +func (sb *sqlserverBuilder) CreateTable(pTable schema.Table) string { lColumns := "" lPKColumns := "" for _, lColumn := range pTable.Columns() { @@ -32,7 +33,7 @@ func (sb sqlserverBuilder) CreateTable(pTable schema.Table) string { if lPKColumns != "" { lPKColumns += "," } - lPKColumns += lColumn.Name().String() + lPKColumns += sb.QuoteIdentifier(lColumn.Name()) } } @@ -40,24 +41,25 @@ func (sb sqlserverBuilder) CreateTable(pTable schema.Table) string { if lColumns != "" { lColumns += "," } - lColumns += fmt.Sprintf("CONSTRAINT PK_%s PRIMARY KEY (%s)", pTable.Name().String(), lPKColumns) + lPKConstraintName := sb.QuoteIdentifier("PK_" + pTable.Name()) + lColumns += fmt.Sprintf("CONSTRAINT %s PRIMARY KEY (%s)", lPKConstraintName, lPKColumns) } - return fmt.Sprintf("CREATE TABLE %s (%s);", pTable.Name().String(), lColumns) + return fmt.Sprintf("CREATE TABLE %s (%s);", sb.QuoteIdentifier(pTable.Name()), lColumns) } -func (sb sqlserverBuilder) DropTable(pTableName schema.TableName) string { +func (sb *sqlserverBuilder) DropTable(pTableName string) string { return sb.generic.DropTable(pTableName) } -func (sb sqlserverBuilder) AlterTableAddColumn(pTableName schema.TableName, pColumn schema.Column) string { +func (sb *sqlserverBuilder) AlterTableAddColumn(pTableName string, pColumn schema.Column) string { return sb.generic.AlterTableAddColumn(pTableName, pColumn) } -func (sb sqlserverBuilder) AlterTableDropColumn(pTableName schema.TableName, pColumnName schema.ColumnName) string { +func (sb *sqlserverBuilder) AlterTableDropColumn(pTableName, pColumnName string) string { return sb.generic.AlterTableDropColumn(pTableName, pColumnName) } -func (sb sqlserverBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) string { +func (sb *sqlserverBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) string { switch pColumnType { case schema.Varchar: return "VARCHAR" @@ -78,8 +80,8 @@ func (sb sqlserverBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) str } } -func (sb sqlserverBuilder) columnDefinition(pColumn schema.Column) string { - lColumnDef := pColumn.Name().String() +func (sb *sqlserverBuilder) columnDefinition(pColumn schema.Column) string { + lColumnDef := sb.QuoteIdentifier(pColumn.Name()) if pColumn.Type() == schema.Varchar { lColumnDef += fmt.Sprintf(" %s(%d)", sb.ColumnTypeAsString(pColumn.Type()), pColumn.Size()) } else { @@ -94,39 +96,39 @@ func (sb sqlserverBuilder) columnDefinition(pColumn schema.Column) string { return lColumnDef } -func (sb sqlserverBuilder) Insert(pTableName string, pColumns []string) string { +func (sb *sqlserverBuilder) Insert(pTableName string, pColumns []string) string { return sb.generic.Insert(pTableName, pColumns) } -func (sb sqlserverBuilder) InsertWithValues(pTableName string, pColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *sqlserverBuilder) InsertWithValues(pTableName string, pColumnsValues ormshift.ColumnsValues) (string, []any) { return sb.generic.InsertWithValues(pTableName, pColumnsValues) } -func (sb sqlserverBuilder) Update(pTableName string, pColumns, pColumnsWhere []string) string { +func (sb *sqlserverBuilder) Update(pTableName string, pColumns, pColumnsWhere []string) string { return sb.generic.Update(pTableName, pColumns, pColumnsWhere) } -func (sb sqlserverBuilder) UpdateWithValues(pTableName string, pColumns, pColumnsWhere []string, pValues ormshift.ColumnsValues) (string, []any) { +func (sb *sqlserverBuilder) UpdateWithValues(pTableName string, pColumns, pColumnsWhere []string, pValues ormshift.ColumnsValues) (string, []any) { return sb.generic.UpdateWithValues(pTableName, pColumns, pColumnsWhere, pValues) } -func (sb sqlserverBuilder) Delete(pTableName string, pColumnsWhere []string) string { +func (sb *sqlserverBuilder) Delete(pTableName string, pColumnsWhere []string) string { return sb.generic.Delete(pTableName, pColumnsWhere) } -func (sb sqlserverBuilder) DeleteWithValues(pTableName string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *sqlserverBuilder) DeleteWithValues(pTableName string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { return sb.generic.DeleteWithValues(pTableName, pWhereColumnsValues) } -func (sb sqlserverBuilder) Select(pTableName string, pColumns, pColumnsWhere []string) string { +func (sb *sqlserverBuilder) Select(pTableName string, pColumns, pColumnsWhere []string) string { return sb.generic.Select(pTableName, pColumns, pColumnsWhere) } -func (sb sqlserverBuilder) SelectWithValues(pTableName string, pColumns []string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *sqlserverBuilder) SelectWithValues(pTableName string, pColumns []string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { return sb.generic.SelectWithValues(pTableName, pColumns, pWhereColumnsValues) } -func (sb sqlserverBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsPerPage, pPageNumber uint) string { +func (sb *sqlserverBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsPerPage, pPageNumber uint) string { lSelectWithPagination := pSQLSelectCommand if pRowsPerPage > 0 { lOffSet := uint(0) @@ -138,6 +140,14 @@ func (sb sqlserverBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsP return lSelectWithPagination } -func (sb sqlserverBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { +func (sb *sqlserverBuilder) QuoteIdentifier(pIdentifier string) string { + // SQL Server uses square brackets: [identifier] + // Escape rule: ] becomes ]] + // Example: users -> [users], table]name -> [table]]name] + pIdentifier = strings.ReplaceAll(pIdentifier, "]", "]]") + return fmt.Sprintf("[%s]", pIdentifier) +} + +func (sb *sqlserverBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { return sb.generic.InteroperateSQLCommandWithNamedArgs(pSQLCommand, pNamedArgs...) } diff --git a/dialects/sqlserver/builder_test.go b/dialects/sqlserver/builder_test.go index 30a811c..37fe538 100644 --- a/dialects/sqlserver/builder_test.go +++ b/dialects/sqlserver/builder_test.go @@ -22,15 +22,15 @@ func TestCreateTable(t *testing.T) { lSQLBuilder := sqlserver.Driver().SQLBuilder() lUserTable := testutils.FakeUserTable(t) - lExpectedSQL := "CREATE TABLE user (id BIGINT NOT NULL IDENTITY (1, 1),email VARCHAR(80) NOT NULL,name VARCHAR(50) NOT NULL," + - "password_hash VARCHAR(256),active BIT,created_at DATETIME2(6),user_master BIGINT,master_user_id BIGINT," + - "licence_price MONEY,relevance FLOAT,photo VARBINARY(MAX),any VARCHAR,CONSTRAINT PK_user PRIMARY KEY (id,email));" - lReturnedSQL := lSQLBuilder.CreateTable(*lUserTable) + lExpectedSQL := "CREATE TABLE [user] ([id] BIGINT NOT NULL IDENTITY (1, 1),[email] VARCHAR(80) NOT NULL,[name] VARCHAR(50) NOT NULL," + + "[password_hash] VARCHAR(256),[active] BIT,[created_at] DATETIME2(6),[user_master] BIGINT,[master_user_id] BIGINT," + + "[licence_price] MONEY,[relevance] FLOAT,[photo] VARBINARY(MAX),[any] VARCHAR,CONSTRAINT [PK_user] PRIMARY KEY ([id],[email]));" + lReturnedSQL := lSQLBuilder.CreateTable(lUserTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") lProductAttributeTable := testutils.FakeProductAttributeTable(t) - lExpectedSQL = "CREATE TABLE product_attribute (product_id BIGINT NOT NULL,attribute_id BIGINT NOT NULL,value VARCHAR(75),position BIGINT,CONSTRAINT PK_product_attribute PRIMARY KEY (product_id,attribute_id));" - lReturnedSQL = lSQLBuilder.CreateTable(*lProductAttributeTable) + lExpectedSQL = "CREATE TABLE [product_attribute] ([product_id] BIGINT NOT NULL,[attribute_id] BIGINT NOT NULL,[value] VARCHAR(75),[position] BIGINT,CONSTRAINT [PK_product_attribute] PRIMARY KEY ([product_id],[attribute_id]));" + lReturnedSQL = lSQLBuilder.CreateTable(lProductAttributeTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") } @@ -38,8 +38,8 @@ func TestDropTable(t *testing.T) { lSQLBuilder := sqlserver.Driver().SQLBuilder() lUserTableName := testutils.FakeUserTableName(t) - lExpectedSQL := "DROP TABLE user;" - lReturnedSQL := lSQLBuilder.DropTable(*lUserTableName) + lExpectedSQL := "DROP TABLE [user];" + lReturnedSQL := lSQLBuilder.DropTable(lUserTableName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.DropTable") } @@ -48,8 +48,8 @@ func TestAlterTableAddColumn(t *testing.T) { lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumn := testutils.FakeUpdatedAtColumn(t) - lExpectedSQL := "ALTER TABLE user ADD COLUMN updated_at DATETIME2(6);" - lReturnedSQL := lSQLBuilder.AlterTableAddColumn(*lUserTableName, *lUpdatedAtColumn) + lExpectedSQL := "ALTER TABLE [user] ADD COLUMN [updated_at] DATETIME2(6);" + lReturnedSQL := lSQLBuilder.AlterTableAddColumn(lUserTableName, lUpdatedAtColumn) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableAddColumn") } @@ -58,8 +58,8 @@ func TestAlterTableDropColumn(t *testing.T) { lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumnName := testutils.FakeUpdatedAtColumnName(t) - lExpectedSQL := "ALTER TABLE user DROP COLUMN updated_at;" - lReturnedSQL := lSQLBuilder.AlterTableDropColumn(*lUserTableName, *lUpdatedAtColumnName) + lExpectedSQL := "ALTER TABLE [user] DROP COLUMN [updated_at];" + lReturnedSQL := lSQLBuilder.AlterTableDropColumn(lUserTableName, lUpdatedAtColumnName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableDropColumn") } @@ -67,7 +67,7 @@ func TestInsert(t *testing.T) { lSQLBuilder := sqlserver.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Insert("product", []string{"id", "sku", "name", "description"}) - lExpectedSQL := "insert into product (id,sku,name,description) values (@id,@sku,@name,@description)" + lExpectedSQL := "insert into [product] ([id],[sku],[name],[description]) values (@id,@sku,@name,@description)" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Insert") } @@ -75,7 +75,7 @@ func TestInsertWithValues(t *testing.T) { lSQLBuilder := sqlserver.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.InsertWithValues("product", ormshift.ColumnsValues{"id": 1, "sku": "1.005.12.9", "name": "Trufa Sabor Amarula 30g Cacaushow"}) - lExpectedSQL := "insert into product (id,name,sku) values (@id,@name,@sku)" + lExpectedSQL := "insert into [product] ([id],[name],[sku]) values (@id,@name,@sku)" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.InsertWithValues.SQL") testutils.AssertEqualWithLabel(t, 3, len(lReturnedValues), "SQLBuilder.InsertWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "id", Value: 1}, "SQLBuilder.InsertWithValues.Values[0]") @@ -87,7 +87,7 @@ func TestUpdate(t *testing.T) { lSQLBuilder := sqlserver.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Update("product", []string{"sku", "name", "description"}, []string{"id"}) - lExpectedSQL := "update product set sku = @sku,name = @name,description = @description where id = @id" + lExpectedSQL := "update [product] set [sku] = @sku,[name] = @name,[description] = @description where [id] = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Update") } @@ -95,7 +95,7 @@ func TestUpdateWithValues(t *testing.T) { lSQLBuilder := sqlserver.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.UpdateWithValues("product", []string{"sku", "name"}, []string{"id"}, ormshift.ColumnsValues{"id": 1, "sku": "1.005.12.5", "name": "Trufa Sabor Amarula 18g Cacaushow"}) - lExpectedSQL := "update product set sku = @sku,name = @name where id = @id" + lExpectedSQL := "update [product] set [sku] = @sku,[name] = @name where [id] = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.UpdateWithValues.SQL") testutils.AssertEqualWithLabel(t, 3, len(lReturnedValues), "SQLBuilder.UpdateWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "id", Value: 1}, "SQLBuilder.UpdateWithValues.Values[0]") @@ -107,7 +107,7 @@ func TestDelete(t *testing.T) { lSQLBuilder := sqlserver.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Delete("product", []string{"id"}) - lExpectedSQL := "delete from product where id = @id" + lExpectedSQL := "delete from [product] where [id] = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Delete") } @@ -115,7 +115,7 @@ func TestDeleteWithValues(t *testing.T) { lSQLBuilder := sqlserver.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.DeleteWithValues("product", ormshift.ColumnsValues{"id": 1}) - lExpectedSQL := "delete from product where id = @id" + lExpectedSQL := "delete from [product] where [id] = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.DeleteWithValues.SQL") testutils.AssertEqualWithLabel(t, 1, len(lReturnedValues), "SQLBuilder.DeleteWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "id", Value: 1}, "SQLBuilder.DeleteWithValues.Values[0]") @@ -125,7 +125,7 @@ func TestSelect(t *testing.T) { lSQLBuilder := sqlserver.Driver().SQLBuilder() lReturnedSQL := lSQLBuilder.Select("product", []string{"id", "name", "description"}, []string{"sku", "active"}) - lExpectedSQL := "select id,name,description from product where sku = @sku and active = @active" + lExpectedSQL := "select [id],[name],[description] from [product] where [sku] = @sku and [active] = @active" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Select") } @@ -133,7 +133,7 @@ func TestSelectWithValues(t *testing.T) { lSQLBuilder := sqlserver.Driver().SQLBuilder() lReturnedSQL, lReturnedValues := lSQLBuilder.SelectWithValues("product", []string{"id", "sku", "name", "description"}, ormshift.ColumnsValues{"category_id": 1, "active": true}) - lExpectedSQL := "select id,sku,name,description from product where active = @active and category_id = @category_id" + lExpectedSQL := "select [id],[sku],[name],[description] from [product] where [active] = @active and [category_id] = @category_id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.SelectWithValues.SQL") testutils.AssertEqualWithLabel(t, 2, len(lReturnedValues), "SQLBuilder.SelectWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "active", Value: true}, "SQLBuilder.SelectWithValues.Values[0]") diff --git a/dialects/sqlserver/driver.go b/dialects/sqlserver/driver.go index d63ff96..e415ade 100644 --- a/dialects/sqlserver/driver.go +++ b/dialects/sqlserver/driver.go @@ -11,17 +11,21 @@ import ( "github.com/ordershift/ormshift/schema" ) -type sqlserverDriver struct{} +type sqlserverDriver struct { + sqlBuilder ormshift.SQLBuilder +} func Driver() ormshift.DatabaseDriver { - return sqlserverDriver{} + return &sqlserverDriver{ + sqlBuilder: newSQLServerBuilder(), + } } -func (d sqlserverDriver) Name() string { +func (d *sqlserverDriver) Name() string { return "sqlserver" } -func (d sqlserverDriver) ConnectionString(pParams ormshift.ConnectionParams) string { +func (d *sqlserverDriver) ConnectionString(pParams ormshift.ConnectionParams) string { lHostInstanceAndPort := pParams.Host if pParams.Instance != "" { lHostInstanceAndPort = fmt.Sprintf("%s\\%s", pParams.Host, pParams.Instance) @@ -32,10 +36,10 @@ func (d sqlserverDriver) ConnectionString(pParams ormshift.ConnectionParams) str return fmt.Sprintf("server=%s;user id=%s;password=%s;database=%s", lHostInstanceAndPort, pParams.User, pParams.Password, pParams.Database) } -func (d sqlserverDriver) SQLBuilder() ormshift.SQLBuilder { - return newSQLServerBuilder() +func (d *sqlserverDriver) SQLBuilder() ormshift.SQLBuilder { + return d.sqlBuilder } -func (d sqlserverDriver) DBSchema(pDB *sql.DB) (*schema.DBSchema, error) { - return schema.NewDBSchema(pDB, tableNamesQuery) +func (d *sqlserverDriver) DBSchema(pDB *sql.DB) (*schema.DBSchema, error) { + return schema.NewDBSchema(pDB, tableNamesQuery, columnTypesQueryFunc(d.sqlBuilder)) } diff --git a/dialects/sqlserver/schema.go b/dialects/sqlserver/schema.go index d4019cf..764086c 100644 --- a/dialects/sqlserver/schema.go +++ b/dialects/sqlserver/schema.go @@ -1,5 +1,11 @@ package sqlserver +import ( + "fmt" + + "github.com/ordershift/ormshift" +) + const tableNamesQuery = ` SELECT t.name @@ -15,3 +21,9 @@ const tableNamesQuery = ` ORDER BY t.name ` + +func columnTypesQueryFunc(pSQLBuilder ormshift.SQLBuilder) func(string) string { + return func(pTableName string) string { + return fmt.Sprintf("SELECT * FROM %s WHERE 1=0", pSQLBuilder.QuoteIdentifier(pTableName)) + } +} diff --git a/dialects/sqlserver/schema_test.go b/dialects/sqlserver/schema_test.go new file mode 100644 index 0000000..db45dcf --- /dev/null +++ b/dialects/sqlserver/schema_test.go @@ -0,0 +1,33 @@ +package sqlserver_test + +import ( + "testing" + + "github.com/ordershift/ormshift" + "github.com/ordershift/ormshift/dialects/sqlserver" + "github.com/ordershift/ormshift/internal/testutils" +) + +func TestHasTable(t *testing.T) { + lDB, lError := ormshift.OpenDatabase(sqlserver.Driver(), ormshift.ConnectionParams{InMemory: true}) + if !testutils.AssertNotNilResultAndNilError(t, lDB, lError, "ormshift.OpenDatabase") { + return + } + defer func() { _ = lDB.Close() }() + + if lDB.DBSchema().HasTable("user") { + t.Errorf("Expected HasTable('user') to be false") + } +} + +func TestHasColumn(t *testing.T) { + lDB, lError := ormshift.OpenDatabase(sqlserver.Driver(), ormshift.ConnectionParams{InMemory: true}) + if !testutils.AssertNotNilResultAndNilError(t, lDB, lError, "ormshift.OpenDatabase") { + return + } + defer func() { _ = lDB.Close() }() + + if lDB.DBSchema().HasColumn("user", "id") { + t.Errorf("Expected HasColumn('user', 'id') to be false") + } +} diff --git a/internal/builder_generic.go b/internal/builder_generic.go index 858428d..1f061bd 100644 --- a/internal/builder_generic.go +++ b/internal/builder_generic.go @@ -11,21 +11,29 @@ import ( type ColumnDefinitionFunc func(schema.Column) string +type QuoteIdentifierFunc func(string) string + type InteroperateSQLCommandWithNamedArgsFunc func(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) type genericSQLBuilder struct { ColumnDefinitionFunc ColumnDefinitionFunc InteroperateSQLCommandWithNamedArgsFunc InteroperateSQLCommandWithNamedArgsFunc + QuoteIdentifierFunc QuoteIdentifierFunc } -func NewGenericSQLBuilder(pColumnDefinitionFunc ColumnDefinitionFunc, pInteroperateSQLCommandWithNamedArgsFunc InteroperateSQLCommandWithNamedArgsFunc) ormshift.SQLBuilder { - return genericSQLBuilder{ +func NewGenericSQLBuilder( + pColumnDefinitionFunc ColumnDefinitionFunc, + pQuoteIdentifierFunc QuoteIdentifierFunc, + pInteroperateSQLCommandWithNamedArgsFunc InteroperateSQLCommandWithNamedArgsFunc, +) ormshift.SQLBuilder { + return &genericSQLBuilder{ ColumnDefinitionFunc: pColumnDefinitionFunc, + QuoteIdentifierFunc: pQuoteIdentifierFunc, InteroperateSQLCommandWithNamedArgsFunc: pInteroperateSQLCommandWithNamedArgsFunc, } } -func (sb genericSQLBuilder) CreateTable(pTable schema.Table) string { +func (sb *genericSQLBuilder) CreateTable(pTable schema.Table) string { lColumns := "" lPKColumns := "" for _, lColumn := range pTable.Columns() { @@ -38,7 +46,7 @@ func (sb genericSQLBuilder) CreateTable(pTable schema.Table) string { if lPKColumns != "" { lPKColumns += "," } - lPKColumns += lColumn.Name().String() + lPKColumns += sb.QuoteIdentifier(lColumn.Name()) } } @@ -48,86 +56,86 @@ func (sb genericSQLBuilder) CreateTable(pTable schema.Table) string { } lColumns += fmt.Sprintf("PRIMARY KEY (%s)", lPKColumns) } - return fmt.Sprintf("CREATE TABLE %s (%s);", pTable.Name().String(), lColumns) + return fmt.Sprintf("CREATE TABLE %s (%s);", sb.QuoteIdentifier(pTable.Name()), lColumns) } -func (sb genericSQLBuilder) DropTable(pTableName schema.TableName) string { - return fmt.Sprintf("DROP TABLE %s;", pTableName.String()) +func (sb *genericSQLBuilder) DropTable(pTableName string) string { + return fmt.Sprintf("DROP TABLE %s;", sb.QuoteIdentifier(pTableName)) } -func (sb genericSQLBuilder) AlterTableAddColumn(pTableName schema.TableName, pColumn schema.Column) string { - return fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s;", pTableName.String(), sb.columnDefinition(pColumn)) +func (sb *genericSQLBuilder) AlterTableAddColumn(pTableName string, pColumn schema.Column) string { + return fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s;", sb.QuoteIdentifier(pTableName), sb.columnDefinition(pColumn)) } -func (sb genericSQLBuilder) AlterTableDropColumn(pTableName schema.TableName, pColumnName schema.ColumnName) string { - return fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s;", pTableName.String(), pColumnName.String()) +func (sb *genericSQLBuilder) AlterTableDropColumn(pTableName, pColumnName string) string { + return fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s;", sb.QuoteIdentifier(pTableName), sb.QuoteIdentifier(pColumnName)) } -func (sb genericSQLBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) string { +func (sb *genericSQLBuilder) ColumnTypeAsString(pColumnType schema.ColumnType) string { // Generic implementation, should be overridden by specific SQL builders return fmt.Sprintf("<>", pColumnType) } -func (sb genericSQLBuilder) columnDefinition(pColumn schema.Column) string { +func (sb *genericSQLBuilder) columnDefinition(pColumn schema.Column) string { if sb.ColumnDefinitionFunc != nil { return sb.ColumnDefinitionFunc(pColumn) } - return fmt.Sprintf("%s %s", pColumn.Name().String(), sb.ColumnTypeAsString(pColumn.Type())) + return fmt.Sprintf("%s %s", sb.QuoteIdentifier(pColumn.Name()), sb.ColumnTypeAsString(pColumn.Type())) } -func (sb genericSQLBuilder) Insert(pTableName string, pColumns []string) string { - return fmt.Sprintf("insert into %s (%s) values (%s)", pTableName, sb.columnsList(pColumns), sb.namesList(pColumns)) +func (sb *genericSQLBuilder) Insert(pTableName string, pColumns []string) string { + return fmt.Sprintf("insert into %s (%s) values (%s)", sb.QuoteIdentifier(pTableName), sb.columnsList(pColumns), sb.namesList(pColumns)) } -func (sb genericSQLBuilder) InsertWithValues(pTableName string, pColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *genericSQLBuilder) InsertWithValues(pTableName string, pColumnsValues ormshift.ColumnsValues) (string, []any) { lInsertSQL := sb.Insert(pTableName, pColumnsValues.ToColumns()) lInsertArgs := pColumnsValues.ToNamedArgs() return sb.InteroperateSQLCommandWithNamedArgs(lInsertSQL, lInsertArgs...) } -func (sb genericSQLBuilder) Update(pTableName string, pColumns, pColumnsWhere []string) string { - lUpdate := fmt.Sprintf("update %s set %s ", pTableName, sb.columnEqualNameList(pColumns, ",")) +func (sb *genericSQLBuilder) Update(pTableName string, pColumns, pColumnsWhere []string) string { + lUpdate := fmt.Sprintf("update %s set %s ", sb.QuoteIdentifier(pTableName), sb.columnEqualNameList(pColumns, ",")) if len(pColumnsWhere) > 0 { lUpdate += fmt.Sprintf("where %s", sb.columnEqualNameList(pColumnsWhere, " and ")) // NOSONAR go:S1192 - duplicate tradeoff accepted } return lUpdate } -func (sb genericSQLBuilder) UpdateWithValues(pTableName string, pColumns, pColumnsWhere []string, pValues ormshift.ColumnsValues) (string, []any) { +func (sb *genericSQLBuilder) UpdateWithValues(pTableName string, pColumns, pColumnsWhere []string, pValues ormshift.ColumnsValues) (string, []any) { lUpdateSQL := sb.Update(pTableName, pColumns, pColumnsWhere) lUpdateArgs := pValues.ToNamedArgs() return sb.InteroperateSQLCommandWithNamedArgs(lUpdateSQL, lUpdateArgs...) } -func (sb genericSQLBuilder) Delete(pTableName string, pColumnsWhere []string) string { - lDelete := fmt.Sprintf("delete from %s ", pTableName) +func (sb *genericSQLBuilder) Delete(pTableName string, pColumnsWhere []string) string { + lDelete := fmt.Sprintf("delete from %s ", sb.QuoteIdentifier(pTableName)) if len(pColumnsWhere) > 0 { lDelete += fmt.Sprintf("where %s", sb.columnEqualNameList(pColumnsWhere, " and ")) // NOSONAR go:S1192 - duplicate tradeoff accepted } return lDelete } -func (sb genericSQLBuilder) DeleteWithValues(pTableName string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *genericSQLBuilder) DeleteWithValues(pTableName string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { lDeleteSQL := sb.Delete(pTableName, pWhereColumnsValues.ToColumns()) lDeleteArgs := pWhereColumnsValues.ToNamedArgs() return sb.InteroperateSQLCommandWithNamedArgs(lDeleteSQL, lDeleteArgs...) } -func (sb genericSQLBuilder) Select(pTableName string, pColumns, pColumnsWhere []string) string { - lUpdate := fmt.Sprintf("select %s from %s ", sb.columnsList(pColumns), pTableName) +func (sb *genericSQLBuilder) Select(pTableName string, pColumns, pColumnsWhere []string) string { + lUpdate := fmt.Sprintf("select %s from %s ", sb.columnsList(pColumns), sb.QuoteIdentifier(pTableName)) if len(pColumnsWhere) > 0 { lUpdate += fmt.Sprintf("where %s", sb.columnEqualNameList(pColumnsWhere, " and ")) // NOSONAR go:S1192 - duplicate tradeoff accepted } return lUpdate } -func (sb genericSQLBuilder) SelectWithValues(pTableName string, pColumns []string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { +func (sb *genericSQLBuilder) SelectWithValues(pTableName string, pColumns []string, pWhereColumnsValues ormshift.ColumnsValues) (string, []any) { lSelectSQL := sb.Select(pTableName, pColumns, pWhereColumnsValues.ToColumns()) lSelectArgs := pWhereColumnsValues.ToNamedArgs() return sb.InteroperateSQLCommandWithNamedArgs(lSelectSQL, lSelectArgs...) } -func (sb genericSQLBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsPerPage, pPageNumber uint) string { +func (sb *genericSQLBuilder) SelectWithPagination(pSQLSelectCommand string, pRowsPerPage, pPageNumber uint) string { lSelectWithPagination := pSQLSelectCommand if pRowsPerPage > 0 { lSelectWithPagination += fmt.Sprintf(" LIMIT %d", pRowsPerPage) @@ -138,11 +146,15 @@ func (sb genericSQLBuilder) SelectWithPagination(pSQLSelectCommand string, pRows return lSelectWithPagination } -func (sb genericSQLBuilder) columnsList(pColumns []string) string { - return strings.Join(pColumns, ",") +func (sb *genericSQLBuilder) columnsList(pColumns []string) string { + lQuotedColumns := []string{} + for _, col := range pColumns { + lQuotedColumns = append(lQuotedColumns, sb.QuoteIdentifier(col)) + } + return strings.Join(lQuotedColumns, ",") } -func (sb genericSQLBuilder) namesList(pColumns []string) string { +func (sb *genericSQLBuilder) namesList(pColumns []string) string { lNames := []string{} for _, lColumn := range pColumns { lNames = append(lNames, "@"+lColumn) @@ -150,21 +162,34 @@ func (sb genericSQLBuilder) namesList(pColumns []string) string { return strings.Join(lNames, ",") } -func (sb genericSQLBuilder) columnEqualNameList(pColumns []string, pSeparator string) string { +func (sb *genericSQLBuilder) columnEqualNameList(pColumns []string, pSeparator string) string { lColumnEqualNameList := "" for _, lColumn := range pColumns { if lColumnEqualNameList != "" { lColumnEqualNameList += pSeparator } - lColumnEqualNameList += fmt.Sprintf("%s = @%s", lColumn, lColumn) + lColumnEqualNameList += fmt.Sprintf("%s = @%s", sb.QuoteIdentifier(lColumn), lColumn) } return lColumnEqualNameList } -func (sb genericSQLBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { +func (sb *genericSQLBuilder) QuoteIdentifier(pIdentifier string) string { + if sb.QuoteIdentifierFunc != nil { + return sb.QuoteIdentifierFunc(pIdentifier) + } + + // Most databases uses double quotes: "identifier" (PostgreSQL, SQLite, etc.) + // Escape rule: double quote becomes two double quotes + // Example: users -> "users", table"name -> "table""name" + pIdentifier = strings.ReplaceAll(pIdentifier, `"`, `""`) + return fmt.Sprintf(`"%s"`, pIdentifier) +} + +func (sb *genericSQLBuilder) InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any) { if sb.InteroperateSQLCommandWithNamedArgsFunc != nil { return sb.InteroperateSQLCommandWithNamedArgsFunc(pSQLCommand, pNamedArgs...) } + lSQLCommand := pSQLCommand lArgs := []any{} for _, lParam := range pNamedArgs { diff --git a/internal/builder_generic_test.go b/internal/builder_generic_test.go index 2204821..03d1a1c 100644 --- a/internal/builder_generic_test.go +++ b/internal/builder_generic_test.go @@ -11,63 +11,63 @@ import ( ) func TestCreateTable(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lUserTable := testutils.FakeUserTable(t) - lExpectedSQL := "CREATE TABLE user (id <>,email <>,name <>,password_hash <>," + - "active <>,created_at <>,user_master <>,master_user_id <>," + - "licence_price <>,relevance <>,photo <>,any <>,PRIMARY KEY (id,email));" - lReturnedSQL := lSQLBuilder.CreateTable(*lUserTable) + lExpectedSQL := "CREATE TABLE \"user\" (\"id\" <>,\"email\" <>,\"name\" <>,\"password_hash\" <>," + + "\"active\" <>,\"created_at\" <>,\"user_master\" <>,\"master_user_id\" <>," + + "\"licence_price\" <>,\"relevance\" <>,\"photo\" <>,\"any\" <>,PRIMARY KEY (\"id\",\"email\"));" + lReturnedSQL := lSQLBuilder.CreateTable(lUserTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") lProductAttributeTable := testutils.FakeProductAttributeTable(t) - lExpectedSQL = "CREATE TABLE product_attribute (product_id <>,attribute_id <>,value <>,position <>,PRIMARY KEY (product_id,attribute_id));" - lReturnedSQL = lSQLBuilder.CreateTable(*lProductAttributeTable) + lExpectedSQL = "CREATE TABLE \"product_attribute\" (\"product_id\" <>,\"attribute_id\" <>,\"value\" <>,\"position\" <>,PRIMARY KEY (\"product_id\",\"attribute_id\"));" + lReturnedSQL = lSQLBuilder.CreateTable(lProductAttributeTable) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.CreateTable") } func TestDropTable(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lUserTableName := testutils.FakeUserTableName(t) - lExpectedSQL := "DROP TABLE user;" - lReturnedSQL := lSQLBuilder.DropTable(*lUserTableName) + lExpectedSQL := "DROP TABLE \"user\";" + lReturnedSQL := lSQLBuilder.DropTable(lUserTableName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.DropTable") } func TestAlterTableAddColumn(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumn := testutils.FakeUpdatedAtColumn(t) - lExpectedSQL := "ALTER TABLE user ADD COLUMN updated_at <>;" - lReturnedSQL := lSQLBuilder.AlterTableAddColumn(*lUserTableName, *lUpdatedAtColumn) + lExpectedSQL := "ALTER TABLE \"user\" ADD COLUMN \"updated_at\" <>;" + lReturnedSQL := lSQLBuilder.AlterTableAddColumn(lUserTableName, lUpdatedAtColumn) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableAddColumn") } func TestAlterTableDropColumn(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lUserTableName := testutils.FakeUserTableName(t) lUpdatedAtColumnName := testutils.FakeUpdatedAtColumnName(t) - lExpectedSQL := "ALTER TABLE user DROP COLUMN updated_at;" - lReturnedSQL := lSQLBuilder.AlterTableDropColumn(*lUserTableName, *lUpdatedAtColumnName) + lExpectedSQL := "ALTER TABLE \"user\" DROP COLUMN \"updated_at\";" + lReturnedSQL := lSQLBuilder.AlterTableDropColumn(lUserTableName, lUpdatedAtColumnName) testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.AlterTableDropColumn") } func TestInsert(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL := lSQLBuilder.Insert("product", []string{"id", "sku", "name", "description"}) - lExpectedSQL := "insert into product (id,sku,name,description) values (@id,@sku,@name,@description)" + lExpectedSQL := "insert into \"product\" (\"id\",\"sku\",\"name\",\"description\") values (@id,@sku,@name,@description)" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Insert") } func TestInsertWithValues(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL, lReturnedValues := lSQLBuilder.InsertWithValues("product", ormshift.ColumnsValues{"id": 1, "sku": "1.005.12.9", "name": "Trufa Sabor Amarula 30g Cacaushow"}) - lExpectedSQL := "insert into product (id,name,sku) values (@id,@name,@sku)" + lExpectedSQL := "insert into \"product\" (\"id\",\"name\",\"sku\") values (@id,@name,@sku)" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.InsertWithValues.SQL") testutils.AssertEqualWithLabel(t, 3, len(lReturnedValues), "SQLBuilder.InsertWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "id", Value: 1}, "SQLBuilder.InsertWithValues.Values[0]") @@ -76,18 +76,18 @@ func TestInsertWithValues(t *testing.T) { } func TestUpdate(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL := lSQLBuilder.Update("product", []string{"sku", "name", "description"}, []string{"id"}) - lExpectedSQL := "update product set sku = @sku,name = @name,description = @description where id = @id" + lExpectedSQL := "update \"product\" set \"sku\" = @sku,\"name\" = @name,\"description\" = @description where \"id\" = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Update") } func TestUpdateWithValues(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL, lReturnedValues := lSQLBuilder.UpdateWithValues("product", []string{"sku", "name"}, []string{"id"}, ormshift.ColumnsValues{"id": 1, "sku": "1.005.12.5", "name": "Trufa Sabor Amarula 18g Cacaushow"}) - lExpectedSQL := "update product set sku = @sku,name = @name where id = @id" + lExpectedSQL := "update \"product\" set \"sku\" = @sku,\"name\" = @name where \"id\" = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.UpdateWithValues.SQL") testutils.AssertEqualWithLabel(t, 3, len(lReturnedValues), "SQLBuilder.UpdateWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "id", Value: 1}, "SQLBuilder.UpdateWithValues.Values[0]") @@ -96,36 +96,36 @@ func TestUpdateWithValues(t *testing.T) { } func TestDelete(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL := lSQLBuilder.Delete("product", []string{"id"}) - lExpectedSQL := "delete from product where id = @id" + lExpectedSQL := "delete from \"product\" where \"id\" = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Delete") } func TestDeleteWithValues(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL, lReturnedValues := lSQLBuilder.DeleteWithValues("product", ormshift.ColumnsValues{"id": 1}) - lExpectedSQL := "delete from product where id = @id" + lExpectedSQL := "delete from \"product\" where \"id\" = @id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.DeleteWithValues.SQL") testutils.AssertEqualWithLabel(t, 1, len(lReturnedValues), "SQLBuilder.DeleteWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "id", Value: 1}, "SQLBuilder.DeleteWithValues.Values[0]") } func TestSelect(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL := lSQLBuilder.Select("product", []string{"id", "name", "description"}, []string{"sku", "active"}) - lExpectedSQL := "select id,name,description from product where sku = @sku and active = @active" + lExpectedSQL := "select \"id\",\"name\",\"description\" from \"product\" where \"sku\" = @sku and \"active\" = @active" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.Select") } func TestSelectWithValues(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL, lReturnedValues := lSQLBuilder.SelectWithValues("product", []string{"id", "sku", "name", "description"}, ormshift.ColumnsValues{"category_id": 1, "active": true}) - lExpectedSQL := "select id,sku,name,description from product where active = @active and category_id = @category_id" + lExpectedSQL := "select \"id\",\"sku\",\"name\",\"description\" from \"product\" where \"active\" = @active and \"category_id\" = @category_id" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.SelectWithValues.SQL") testutils.AssertEqualWithLabel(t, 2, len(lReturnedValues), "SQLBuilder.SelectWithValues.Values") testutils.AssertNamedArgEqualWithLabel(t, lReturnedValues[0], sql.NamedArg{Name: "active", Value: true}, "SQLBuilder.SelectWithValues.Values[0]") @@ -133,7 +133,7 @@ func TestSelectWithValues(t *testing.T) { } func TestSelectWithPagination(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, nil) lReturnedSQL := lSQLBuilder.SelectWithPagination("select * from product", 10, 5) lExpectedSQL := "select * from product LIMIT 10 OFFSET 40" @@ -141,7 +141,7 @@ func TestSelectWithPagination(t *testing.T) { } func TestInteroperateSQLCommandWithNamedArgs(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(nil, testutils.FakeInteroperateSQLCommandWithNamedArgsFunc) + lSQLBuilder := internal.NewGenericSQLBuilder(nil, nil, testutils.FakeInteroperateSQLCommandWithNamedArgsFunc) lReturnedSQL, lReturnedNamedArgs := lSQLBuilder.InteroperateSQLCommandWithNamedArgs("original command", sql.NamedArg{Name: "param1", Value: 1}) lExpectedSQL := "command has been modified" testutils.AssertEqualWithLabel(t, lExpectedSQL, lReturnedSQL, "SQLBuilder.InteroperateSQLCommandWithNamedArgs.SQL") @@ -150,13 +150,17 @@ func TestInteroperateSQLCommandWithNamedArgs(t *testing.T) { } func TestColumnDefinition(t *testing.T) { - lSQLBuilder := internal.NewGenericSQLBuilder(testutils.FakeColumnDefinitionFunc, nil) - lColumn, lError := schema.NewColumn(schema.NewColumnParams{Name: "column_name", Type: schema.Integer, Size: 0}) - testutils.AssertNilError(t, lError, "schema.NewColumn") - - lTableName, lError := schema.NewTableName("test_table") - testutils.AssertNilError(t, lError, "schema.NewTableName") + lSQLBuilder := internal.NewGenericSQLBuilder(testutils.FakeColumnDefinitionFunc, nil, nil) + lColumn := schema.NewColumn(schema.NewColumnParams{Name: "column_name", Type: schema.Integer, Size: 0}) + lTableName := "test_table" + lReturnedSQL := lSQLBuilder.AlterTableAddColumn(lTableName, lColumn) + testutils.AssertEqualWithLabel(t, "ALTER TABLE \"test_table\" ADD COLUMN fake;", lReturnedSQL, "SQLBuilder.ColumnDefinition") +} - lReturnedSQL := lSQLBuilder.AlterTableAddColumn(*lTableName, *lColumn) - testutils.AssertEqualWithLabel(t, "ALTER TABLE test_table ADD COLUMN fake;", lReturnedSQL, "SQLBuilder.ColumnDefinition") +func TestQuoteIdentifier(t *testing.T) { + lSQLBuilder := internal.NewGenericSQLBuilder(nil, testutils.FakeQuoteIdentifierFunc, nil) + lColumn := schema.NewColumn(schema.NewColumnParams{Name: "column_name", Type: schema.Integer, Size: 0}) + lTableName := "test_table" + lReturnedSQL := lSQLBuilder.AlterTableAddColumn(lTableName, lColumn) + testutils.AssertEqualWithLabel(t, "ALTER TABLE quoted_test_table ADD COLUMN quoted_column_name <>;", lReturnedSQL, "SQLBuilder.QuoteIdentifier") } diff --git a/internal/testutils/fake.go b/internal/testutils/fake.go index 44ce3d7..245bfc4 100644 --- a/internal/testutils/fake.go +++ b/internal/testutils/fake.go @@ -7,214 +7,151 @@ import ( "github.com/ordershift/ormshift/schema" ) -func FakeProductAttributeTable(t *testing.T) *schema.Table { - lProductAttributeTable, lError := schema.NewTable("product_attribute") - if !AssertNotNilResultAndNilError(t, lProductAttributeTable, lError, "schema.NewTable") { - return nil - } - lError = lProductAttributeTable.AddColumn(schema.NewColumnParams{ - Name: "product_id", - Type: schema.Integer, - PrimaryKey: true, - NotNull: true, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "ProductAttributeTable.AddColumn") { - return nil - } - lError = lProductAttributeTable.AddColumn(schema.NewColumnParams{ - Name: "attribute_id", - Type: schema.Integer, - PrimaryKey: true, - NotNull: true, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "ProductAttributeTable.AddColumn") { - return nil - } - lError = lProductAttributeTable.AddColumn(schema.NewColumnParams{ - Name: "value", - Type: schema.Varchar, - Size: 75, - PrimaryKey: false, - NotNull: false, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "ProductAttributeTable.AddColumn") { - return nil - } - lError = lProductAttributeTable.AddColumn(schema.NewColumnParams{ - Name: "position", - Type: schema.Integer, - PrimaryKey: false, - NotNull: false, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "ProductAttributeTable.AddColumn") { - return nil +func FakeProductAttributeTable(t *testing.T) schema.Table { + lProductAttributeTable := schema.NewTable("product_attribute") + lError := lProductAttributeTable.AddColumns( + schema.NewColumnParams{ + Name: "product_id", + Type: schema.Integer, + PrimaryKey: true, + NotNull: true, + }, + schema.NewColumnParams{ + Name: "attribute_id", + Type: schema.Integer, + PrimaryKey: true, + NotNull: true, + }, + schema.NewColumnParams{ + Name: "value", + Type: schema.Varchar, + Size: 75, + }, + schema.NewColumnParams{ + Name: "position", + Type: schema.Integer, + }, + ) + if !AssertNilError(t, lError, "ProductAttributeTable.AddColumns") { + panic(lError) } return lProductAttributeTable } -func FakeUserTable(t *testing.T) *schema.Table { - lUserTable, lError := schema.NewTable("user") - if !AssertNotNilResultAndNilError(t, lUserTable, lError, "schema.NewTable") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "id", - Type: schema.Integer, - PrimaryKey: true, - NotNull: true, - AutoIncrement: true, - }) - if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "email", - Type: schema.Varchar, - Size: 80, - PrimaryKey: true, - NotNull: true, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "name", - Type: schema.Varchar, - Size: 50, - PrimaryKey: false, - NotNull: true, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "password_hash", - Type: schema.Varchar, - Size: 256, - PrimaryKey: false, - NotNull: false, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "active", - Type: schema.Boolean, - PrimaryKey: false, - NotNull: false, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "created_at", - Type: schema.DateTime, - PrimaryKey: false, - NotNull: false, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "user_master", - Type: schema.Integer, - PrimaryKey: false, - NotNull: false, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "master_user_id", - Type: schema.Integer, - PrimaryKey: false, - NotNull: false, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "licence_price", - Type: schema.Monetary, - PrimaryKey: false, - NotNull: false, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "relevance", - Type: schema.Decimal, - PrimaryKey: false, - NotNull: false, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "photo", - Type: schema.Binary, - PrimaryKey: false, - NotNull: false, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil - } - lError = lUserTable.AddColumn(schema.NewColumnParams{ - Name: "any", - Type: schema.ColumnType(-1), - PrimaryKey: false, - NotNull: false, - AutoIncrement: false, - }) - if !AssertNilError(t, lError, "UserTable.AddColumn") { - return nil +func FakeUserTable(t *testing.T) schema.Table { + lUserTable := schema.NewTable("user") + lError := lUserTable.AddColumns( + schema.NewColumnParams{ + Name: "id", + Type: schema.Integer, + PrimaryKey: true, + NotNull: true, + AutoIncrement: true, + }, + schema.NewColumnParams{ + Name: "email", + Type: schema.Varchar, + Size: 80, + PrimaryKey: true, + NotNull: true, + AutoIncrement: false, + }, + schema.NewColumnParams{ + Name: "name", + Type: schema.Varchar, + Size: 50, + PrimaryKey: false, + NotNull: true, + AutoIncrement: false, + }, + schema.NewColumnParams{ + Name: "password_hash", + Type: schema.Varchar, + Size: 256, + PrimaryKey: false, + NotNull: false, + AutoIncrement: false, + }, + schema.NewColumnParams{ + Name: "active", + Type: schema.Boolean, + PrimaryKey: false, + NotNull: false, + AutoIncrement: false, + }, + schema.NewColumnParams{ + Name: "created_at", + Type: schema.DateTime, + PrimaryKey: false, + NotNull: false, + AutoIncrement: false, + }, + schema.NewColumnParams{ + Name: "user_master", + Type: schema.Integer, + PrimaryKey: false, + NotNull: false, + AutoIncrement: false, + }, + schema.NewColumnParams{ + Name: "master_user_id", + Type: schema.Integer, + PrimaryKey: false, + NotNull: false, + AutoIncrement: false, + }, + schema.NewColumnParams{ + Name: "licence_price", + Type: schema.Monetary, + PrimaryKey: false, + NotNull: false, + AutoIncrement: false, + }, + schema.NewColumnParams{ + Name: "relevance", + Type: schema.Decimal, + PrimaryKey: false, + NotNull: false, + AutoIncrement: false, + }, + schema.NewColumnParams{ + Name: "photo", + Type: schema.Binary, + PrimaryKey: false, + NotNull: false, + AutoIncrement: false, + }, + schema.NewColumnParams{ + Name: "any", + Type: schema.ColumnType(-1), + PrimaryKey: false, + NotNull: false, + AutoIncrement: false, + }, + ) + if !AssertNilError(t, lError, "UserTable.AddColumns") { + panic(lError) } return lUserTable } -func FakeUserTableName(t *testing.T) *schema.TableName { - lUserTableName, lError := schema.NewTableName("user") - if !AssertNotNilResultAndNilError(t, lUserTableName, lError, "schema.NewTableName") { - return nil - } - return lUserTableName +func FakeUserTableName(t *testing.T) string { + return "user" } -func FakeUpdatedAtColumn(t *testing.T) *schema.Column { - lUpdatedAtColumn, lError := schema.NewColumn(schema.NewColumnParams{ +func FakeUpdatedAtColumn(t *testing.T) schema.Column { + lUpdatedAtColumn := schema.NewColumn(schema.NewColumnParams{ Name: "updated_at", Type: schema.DateTime, PrimaryKey: false, NotNull: false, AutoIncrement: false, }) - if !AssertNotNilResultAndNilError(t, lUpdatedAtColumn, lError, "schema.NewColumn") { - return nil - } return lUpdatedAtColumn } -func FakeUpdatedAtColumnName(t *testing.T) *schema.ColumnName { - lUpdatedAtColumnName, lError := schema.NewColumnName("updated_at") - if !AssertNotNilResultAndNilError(t, lUpdatedAtColumnName, lError, "schema.NewColumnName") { - return nil - } - return lUpdatedAtColumnName +func FakeUpdatedAtColumnName(t *testing.T) string { + return "updated_at" } func FakeInteroperateSQLCommandWithNamedArgsFunc(command string, namedArgs ...sql.NamedArg) (string, []any) { @@ -228,3 +165,7 @@ func FakeInteroperateSQLCommandWithNamedArgsFunc(command string, namedArgs ...sq func FakeColumnDefinitionFunc(column schema.Column) string { return "fake" } + +func FakeQuoteIdentifierFunc(identifier string) string { + return "quoted_" + identifier +} diff --git a/internal/testutils/migrations.go b/internal/testutils/migrations.go index 8382f9d..564283c 100644 --- a/internal/testutils/migrations.go +++ b/internal/testutils/migrations.go @@ -11,131 +11,100 @@ import ( type M001_Create_Table_User struct{} func (m M001_Create_Table_User) Up(pMigrator *migrations.Migrator) error { - lUserTable, lError := schema.NewTable("user") - if lError != nil { - return lError - } + lUserTable := schema.NewTable("user") if pMigrator.Database().DBSchema().HasTable(lUserTable.Name()) { return nil } - columns := []schema.NewColumnParams{ - { + + lError := lUserTable.AddColumns( + schema.NewColumnParams{ Name: "id", Type: schema.Integer, AutoIncrement: true, PrimaryKey: true, NotNull: true, }, - { + schema.NewColumnParams{ Name: "name", Type: schema.Varchar, Size: 50, PrimaryKey: false, NotNull: false, }, - { + schema.NewColumnParams{ Name: "email", Type: schema.Varchar, Size: 120, PrimaryKey: false, NotNull: false, }, - { + schema.NewColumnParams{ Name: "active", Type: schema.Boolean, PrimaryKey: false, NotNull: false, }, - { + schema.NewColumnParams{ Name: "ammount", Type: schema.Monetary, PrimaryKey: false, NotNull: false, }, - { + schema.NewColumnParams{ Name: "percent", Type: schema.Decimal, PrimaryKey: false, NotNull: false, }, - { + schema.NewColumnParams{ Name: "photo", Type: schema.Binary, PrimaryKey: false, NotNull: false, }, - } - - for _, col := range columns { - if err := lUserTable.AddColumn(col); err != nil { - return err - } - } - - _, lError = pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().CreateTable(*lUserTable)) + ) if lError != nil { return lError } - return nil + + _, lError = pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().CreateTable(lUserTable)) + return lError } func (m M001_Create_Table_User) Down(pMigrator *migrations.Migrator) error { - lUserTableName, lError := schema.NewTableName("user") - if lError != nil { - return lError - } - if !pMigrator.Database().DBSchema().HasTable(*lUserTableName) { + lUserTableName := "user" + if !pMigrator.Database().DBSchema().HasTable(lUserTableName) { return nil } - _, lError = pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().DropTable(*lUserTableName)) - if lError != nil { - return lError - } - return nil + + _, lError := pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().DropTable(lUserTableName)) + return lError } // M002_Alter_Table_User_Add_Column_UpdatedAt adds the "updated_at" column to the "user" table. type M002_Alter_Table_User_Add_Column_UpdatedAt struct{} func (m M002_Alter_Table_User_Add_Column_UpdatedAt) Up(pMigrator *migrations.Migrator) error { - lUserTableName, lError := schema.NewTableName("user") - if lError != nil { - return lError - } - lUpdatedAtColumn, lError := schema.NewColumn(schema.NewColumnParams{ + lUserTableName := "user" + lUpdatedAtColumn := schema.NewColumn(schema.NewColumnParams{ Name: "updated_at", Type: schema.DateTime, }) - if lError != nil { - return lError - } - if pMigrator.Database().DBSchema().HasColumn(*lUserTableName, lUpdatedAtColumn.Name()) { + if pMigrator.Database().DBSchema().HasColumn(lUserTableName, lUpdatedAtColumn.Name()) { return nil } - _, lError = pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().AlterTableAddColumn(*lUserTableName, *lUpdatedAtColumn)) - if lError != nil { - return lError - } - return nil + _, lError := pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().AlterTableAddColumn(lUserTableName, lUpdatedAtColumn)) + return lError } func (m M002_Alter_Table_User_Add_Column_UpdatedAt) Down(pMigrator *migrations.Migrator) error { - lUserTableName, lError := schema.NewTableName("user") - if lError != nil { - return lError - } - lUpdatedAtColumnName, lError := schema.NewColumnName("updated_at") - if lError != nil { - return lError - } - if !pMigrator.Database().DBSchema().HasColumn(*lUserTableName, *lUpdatedAtColumnName) { + lUserTableName := "user" + lUpdatedAtColumnName := "updated_at" + if !pMigrator.Database().DBSchema().HasColumn(lUserTableName, lUpdatedAtColumnName) { return nil } - _, lError = pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().AlterTableDropColumn(*lUserTableName, *lUpdatedAtColumnName)) - if lError != nil { - return lError - } - return nil + _, lError := pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().AlterTableDropColumn(lUserTableName, lUpdatedAtColumnName)) + return lError } // M003_Bad_Migration_Fails_To_Apply is a migration that always fails to apply. diff --git a/migrations/config.go b/migrations/config.go index 917ca9d..4adb1ef 100644 --- a/migrations/config.go +++ b/migrations/config.go @@ -7,41 +7,39 @@ type MigratorConfig struct { appliedAtColumn string } -func NewMigratorConfig() MigratorConfig { +func NewMigratorConfig() *MigratorConfig { lConfig := MigratorConfig{ tableName: "__ormshift_migrations", migrationNameColumn: "name", migrationNameMaxLength: 250, appliedAtColumn: "applied_at", } - return lConfig + return &lConfig } -func (mc MigratorConfig) WithTableName(pTableName string) MigratorConfig { +func (mc *MigratorConfig) WithTableName(pTableName string) *MigratorConfig { mc.tableName = pTableName return mc } - -func (mc MigratorConfig) WithColumnNames(pMigrationNameColumn, pAppliedAtColumn string) MigratorConfig { +func (mc *MigratorConfig) WithColumnNames(pMigrationNameColumn, pAppliedAtColumn string) *MigratorConfig { mc.migrationNameColumn = pMigrationNameColumn mc.appliedAtColumn = pAppliedAtColumn return mc } - -func (mc MigratorConfig) WithMigrationNameMaxLength(pMaxLength uint) MigratorConfig { +func (mc *MigratorConfig) WithMigrationNameMaxLength(pMaxLength uint) *MigratorConfig { mc.migrationNameMaxLength = pMaxLength return mc } -func (mc MigratorConfig) TableName() string { +func (mc *MigratorConfig) TableName() string { return mc.tableName } -func (mc MigratorConfig) MigrationNameColumn() string { +func (mc *MigratorConfig) MigrationNameColumn() string { return mc.migrationNameColumn } -func (mc MigratorConfig) MigrationNameMaxLength() uint { +func (mc *MigratorConfig) MigrationNameMaxLength() uint { return mc.migrationNameMaxLength } -func (mc MigratorConfig) AppliedAtColumn() string { +func (mc *MigratorConfig) AppliedAtColumn() string { return mc.appliedAtColumn } diff --git a/migrations/migrations.go b/migrations/migrations.go index 1e92128..d14fcd0 100644 --- a/migrations/migrations.go +++ b/migrations/migrations.go @@ -7,7 +7,7 @@ type Migration interface { Down(pMigrator *Migrator) error } -func Migrate(pDatabase *ormshift.Database, pConfig MigratorConfig, pMigrations ...Migration) (*Migrator, error) { +func Migrate(pDatabase *ormshift.Database, pConfig *MigratorConfig, pMigrations ...Migration) (*Migrator, error) { lMigrator, lError := NewMigrator(pDatabase, pConfig) if lError != nil { return nil, lError diff --git a/migrations/migrations_test.go b/migrations/migrations_test.go index e3422be..0b7ef0d 100644 --- a/migrations/migrations_test.go +++ b/migrations/migrations_test.go @@ -7,7 +7,6 @@ import ( "github.com/ordershift/ormshift/dialects/sqlite" "github.com/ordershift/ormshift/internal/testutils" "github.com/ordershift/ormshift/migrations" - "github.com/ordershift/ormshift/schema" ) func TestMigrate(t *testing.T) { @@ -26,16 +25,10 @@ func TestMigrate(t *testing.T) { if !testutils.AssertNotNilResultAndNilError(t, lMigrator, lError, "migrations.Migrate") { return } - lUserTableName, lError := schema.NewTableName("user") - if !testutils.AssertNilError(t, lError, "migrations.NewTableName") { - return - } - lUpdatedAtColumnName, lError := schema.NewColumnName("updated_at") - if !testutils.AssertNilError(t, lError, "migrations.NewColumnName") { - return - } - testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().HasColumn(*lUserTableName, *lUpdatedAtColumnName), "Migrator.DBSchema.HasColumn[user.updated_at]") - testutils.AssertEqualWithLabel(t, 2, len(lMigrator.AppliedMigrations()), "len(Migrator.AppliedMigrationNames)") + lUserTableName := "user" + lUpdatedAtColumnName := "updated_at" + testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().HasColumn(lUserTableName, lUpdatedAtColumnName), "Migrator.DBSchema.HasColumn[user.updated_at]") + testutils.AssertEqualWithLabel(t, 2, len(lMigrator.AppliedMigrations()), "len(Migrator.AppliedMigrations)") } func TestMigrateTwice(t *testing.T) { @@ -65,15 +58,9 @@ func TestMigrateTwice(t *testing.T) { return } - lUserTableName, lError := schema.NewTableName("user") - if !testutils.AssertNilError(t, lError, "migrations.NewTableName") { - return - } - lUpdatedAtColumnName, lError := schema.NewColumnName("updated_at") - if !testutils.AssertNilError(t, lError, "migrations.NewColumnName") { - return - } - testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().HasColumn(*lUserTableName, *lUpdatedAtColumnName), "Migrator.DBSchema.HasColumn[user.updated_at]") + lUserTableName := "user" + lUpdatedAtColumnName := "updated_at" + testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().HasColumn(lUserTableName, lUpdatedAtColumnName), "Migrator.DBSchema.HasColumn[user.updated_at]") testutils.AssertEqualWithLabel(t, 2, len(lMigrator.AppliedMigrations()), "len(Migrator.AppliedMigrations)") } diff --git a/migrations/migrator.go b/migrations/migrator.go index e9a13cf..7a83934 100644 --- a/migrations/migrator.go +++ b/migrations/migrator.go @@ -11,17 +11,20 @@ import ( type Migrator struct { database *ormshift.Database - config MigratorConfig + config *MigratorConfig migrations []Migration appliedMigrations map[string]bool } -func NewMigrator(pDatabase *ormshift.Database, pConfig MigratorConfig) (*Migrator, error) { +func NewMigrator(pDatabase *ormshift.Database, pConfig *MigratorConfig) (*Migrator, error) { if pDatabase == nil { return nil, fmt.Errorf("database cannot be nil") } + if pConfig == nil { + return nil, fmt.Errorf("migrator config cannot be nil") + } - lAppliedMigrationNames, lError := getAppliedMigrationNames(*pDatabase, pConfig) + lAppliedMigrationNames, lError := getAppliedMigrationNames(pDatabase, pConfig) if lError != nil { return nil, fmt.Errorf("failed to get applied migration names: %w", lError) } @@ -80,16 +83,15 @@ func (m *Migrator) RevertLastAppliedMigration() error { return nil } -func (m Migrator) Database() *ormshift.Database { +func (m *Migrator) Database() *ormshift.Database { return m.database } -func (m Migrator) Migrations() []Migration { +func (m *Migrator) Migrations() []Migration { return m.migrations } -func (m Migrator) AppliedMigrations() []Migration { - +func (m *Migrator) AppliedMigrations() []Migration { lMigrations := []Migration{} for _, migration := range m.Migrations() { name := reflect.TypeOf(migration).Name() @@ -100,12 +102,12 @@ func (m Migrator) AppliedMigrations() []Migration { return lMigrations } -func (m Migrator) isApplied(pMigrationName string) bool { +func (m *Migrator) isApplied(pMigrationName string) bool { _, exists := m.appliedMigrations[pMigrationName] return exists } -func (m Migrator) recordAppliedMigration(pMigrationName string) error { +func (m *Migrator) recordAppliedMigration(pMigrationName string) error { q, p := m.database.SQLBuilder().InsertWithValues( m.config.tableName, ormshift.ColumnsValues{ @@ -117,7 +119,7 @@ func (m Migrator) recordAppliedMigration(pMigrationName string) error { return lError } -func (m Migrator) deleteAppliedMigration(pMigrationName string) error { +func (m *Migrator) deleteAppliedMigration(pMigrationName string) error { q, p := m.database.SQLBuilder().DeleteWithValues( m.config.tableName, ormshift.ColumnsValues{ @@ -128,7 +130,7 @@ func (m Migrator) deleteAppliedMigration(pMigrationName string) error { return lError } -func getAppliedMigrationNames(pDatabase ormshift.Database, pConfig MigratorConfig) (rMigrationNames []string, rError error) { +func getAppliedMigrationNames(pDatabase *ormshift.Database, pConfig *MigratorConfig) (rMigrationNames []string, rError error) { rError = ensureMigrationsTableExists(pDatabase, pConfig) if rError != nil { return @@ -137,9 +139,9 @@ func getAppliedMigrationNames(pDatabase ormshift.Database, pConfig MigratorConfi q, p := pDatabase.SQLBuilder().InteroperateSQLCommandWithNamedArgs( fmt.Sprintf( "select %s from %s order by %s", - pConfig.migrationNameColumn, - pConfig.tableName, - pConfig.migrationNameColumn, + pDatabase.SQLBuilder().QuoteIdentifier(pConfig.migrationNameColumn), + pDatabase.SQLBuilder().QuoteIdentifier(pConfig.tableName), + pDatabase.SQLBuilder().QuoteIdentifier(pConfig.migrationNameColumn), ), ) lMigrationsRows, rError := pDatabase.SQLExecutor().Query(q, p...) @@ -162,34 +164,29 @@ func getAppliedMigrationNames(pDatabase ormshift.Database, pConfig MigratorConfi return } -func ensureMigrationsTableExists(pDatabase ormshift.Database, pConfig MigratorConfig) error { - lMigrationsTable, lError := schema.NewTable(pConfig.tableName) +func ensureMigrationsTableExists(pDatabase *ormshift.Database, pConfig *MigratorConfig) error { + lMigrationsTable := schema.NewTable(pConfig.TableName()) + if pDatabase.DBSchema().HasTable(lMigrationsTable.Name()) { + return nil + } + lError := lMigrationsTable.AddColumns( + schema.NewColumnParams{ + Name: pConfig.MigrationNameColumn(), + Type: schema.Varchar, + Size: pConfig.MigrationNameMaxLength(), + PrimaryKey: true, + NotNull: true, + }, + schema.NewColumnParams{ + Name: pConfig.AppliedAtColumn(), + Type: schema.DateTime, + NotNull: true, + }, + ) if lError != nil { return lError } - if !pDatabase.DBSchema().HasTable(lMigrationsTable.Name()) { - columns := []schema.NewColumnParams{ - { - Name: pConfig.migrationNameColumn, - Type: schema.Varchar, - Size: pConfig.migrationNameMaxLength, - PrimaryKey: true, - NotNull: true, - }, - { - Name: pConfig.appliedAtColumn, - Type: schema.DateTime, - NotNull: true, - }, - } - for _, col := range columns { - if err := lMigrationsTable.AddColumn(col); err != nil { - return err - } - } - - _, lError = pDatabase.SQLExecutor().Exec(pDatabase.SQLBuilder().CreateTable(*lMigrationsTable)) // NOSONAR go:S2077 - Dynamic SQL is controlled and sanitized internally - } + _, lError = pDatabase.SQLExecutor().Exec(pDatabase.SQLBuilder().CreateTable(lMigrationsTable)) return lError } diff --git a/migrations/migrator_test.go b/migrations/migrator_test.go index 9e207ba..598b033 100644 --- a/migrations/migrator_test.go +++ b/migrations/migrator_test.go @@ -8,7 +8,6 @@ import ( "github.com/ordershift/ormshift/dialects/sqlite" "github.com/ordershift/ormshift/internal/testutils" "github.com/ordershift/ormshift/migrations" - "github.com/ordershift/ormshift/schema" ) func TestNewMigratorWhenDatabaseIsNil(t *testing.T) { @@ -17,6 +16,18 @@ func TestNewMigratorWhenDatabaseIsNil(t *testing.T) { testutils.AssertErrorMessage(t, "database cannot be nil", lError, "migrations.NewMigrator[database=nil]") } +func TestNewMigratorWhenConfigIsNil(t *testing.T) { + lDB, lError := ormshift.OpenDatabase(sqlite.Driver(), ormshift.ConnectionParams{InMemory: true}) + if !testutils.AssertNotNilResultAndNilError(t, lDB, lError, "ormshift.OpenDatabase") { + return + } + defer func() { _ = lDB.Close() }() + + lMigrator, lError := migrations.NewMigrator(lDB, nil) + testutils.AssertNilResultAndNotNilError(t, lMigrator, lError, "migrations.NewMigrator[config=nil]") + testutils.AssertErrorMessage(t, "migrator config cannot be nil", lError, "migrations.NewMigrator[config=nil]") +} + func TestNewMigratorWhenDatabaseIsInvalid(t *testing.T) { lDriver := testutils.NewFakeDriverInvalidConnectionString(postgresql.Driver()) lDB, lError := ormshift.OpenDatabase(lDriver, ormshift.ConnectionParams{}) @@ -69,21 +80,16 @@ func TestRevertLastAppliedMigration(t *testing.T) { return } - lUserTableName, lError := schema.NewTableName("user") - if !testutils.AssertNilError(t, lError, "migrations.NewTableName") { - return - } - testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().HasTable(*lUserTableName), "Migrator.DBSchema.HasTable[user]") + lUserTableName := "user" + testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().HasTable(lUserTableName), "Migrator.DBSchema.HasTable[user]") lError = lMigrator.RevertLastAppliedMigration() if !testutils.AssertNilError(t, lError, "Migrator.RevertLastAppliedMigration") { return } - lUpdatedAtColumnName, lError := schema.NewColumnName("updated_at") - if !testutils.AssertNilError(t, lError, "migrations.NewColumnName") { - return - } - testutils.AssertEqualWithLabel(t, false, lDB.DBSchema().HasColumn(*lUserTableName, *lUpdatedAtColumnName), "Migrator.DBSchema.HasColumn[user.updated_at]") + + lUpdatedAtColumnName := "updated_at" + testutils.AssertEqualWithLabel(t, false, lDB.DBSchema().HasColumn(lUserTableName, lUpdatedAtColumnName), "Migrator.DBSchema.HasColumn[user.updated_at]") } func TestRevertLastAppliedMigrationWhenNoMigrationsApplied(t *testing.T) { diff --git a/schema/column.go b/schema/column.go index 030fdee..7eed88e 100644 --- a/schema/column.go +++ b/schema/column.go @@ -1,27 +1,5 @@ package schema -import ( - "fmt" - "regexp" -) - -var regexValidColumnName = regexp.MustCompile(`^[A-Za-z][A-Za-z0-9_]*$`) - -type ColumnName struct { - columnName string -} - -func NewColumnName(pName string) (*ColumnName, error) { - if !regexValidColumnName.MatchString(pName) { - return nil, fmt.Errorf("invalid column name: %q", pName) - } - return &ColumnName{pName}, nil -} - -func (tn ColumnName) String() string { - return tn.columnName -} - type ColumnType int const ( @@ -44,49 +22,45 @@ type NewColumnParams struct { } type Column struct { - name ColumnName - columnType ColumnType - size uint - pk bool - notNull bool - autoInc bool + name string + columnType ColumnType + size uint + primaryKey bool + notNull bool + autoIncrement bool } -func NewColumn(pParams NewColumnParams) (*Column, error) { - lColumnName, lError := NewColumnName(pParams.Name) - if lError != nil { - return nil, lError +func NewColumn(pParams NewColumnParams) Column { + return Column{ + name: pParams.Name, + columnType: pParams.Type, + size: pParams.Size, + primaryKey: pParams.PrimaryKey, + notNull: pParams.NotNull, + autoIncrement: pParams.AutoIncrement, } - return &Column{ - name: *lColumnName, - columnType: pParams.Type, - size: pParams.Size, - pk: pParams.PrimaryKey, - notNull: pParams.NotNull, - autoInc: pParams.AutoIncrement, - }, nil } -func (c Column) Name() ColumnName { +func (c *Column) Name() string { return c.name } -func (c Column) Type() ColumnType { +func (c *Column) Type() ColumnType { return c.columnType } -func (c Column) Size() uint { +func (c *Column) Size() uint { return c.size } -func (c Column) PrimaryKey() bool { - return c.pk +func (c *Column) PrimaryKey() bool { + return c.primaryKey } -func (c Column) NotNull() bool { +func (c *Column) NotNull() bool { return c.notNull } -func (c Column) AutoIncrement() bool { - return c.autoInc +func (c *Column) AutoIncrement() bool { + return c.autoIncrement } diff --git a/schema/column_test.go b/schema/column_test.go index f5e5284..0a634a0 100644 --- a/schema/column_test.go +++ b/schema/column_test.go @@ -8,11 +8,9 @@ import ( ) func TestColumn(t *testing.T) { - lColumn, lError := schema.NewColumn(schema.NewColumnParams{Name: "id", Type: schema.Integer, NotNull: true, PrimaryKey: true, AutoIncrement: true}) - if !testutils.AssertNilError(t, lError, "schema.NewColumn") { - return - } - testutils.AssertEqualWithLabel(t, "id", lColumn.Name().String(), "Column.Name") + lColumn := schema.NewColumn(schema.NewColumnParams{Name: "id", Type: schema.Integer, NotNull: true, PrimaryKey: true, AutoIncrement: true}) + + testutils.AssertEqualWithLabel(t, "id", lColumn.Name(), "Column.Name") testutils.AssertEqualWithLabel(t, schema.Integer, lColumn.Type(), "Column.Type") testutils.AssertEqualWithLabel(t, uint(0), lColumn.Size(), "Column.Size") testutils.AssertEqualWithLabel(t, true, lColumn.PrimaryKey(), "Column.IsPrimaryKey") diff --git a/schema/schema.go b/schema/schema.go index 87857c3..4c48360 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -3,34 +3,44 @@ package schema import ( "database/sql" "errors" - "fmt" "slices" "strings" ) type DBSchema struct { - db *sql.DB - tableNamesQuery string + db *sql.DB + tableNamesQuery string + columnTypesQueryFunc ColumnTypesQueryFunc } -func NewDBSchema(pDB *sql.DB, pTableNamesQuery string) (*DBSchema, error) { +type ColumnTypesQueryFunc func(pTableName string) string + +func NewDBSchema( + pDB *sql.DB, + pTableNamesQuery string, + pColumnTypesQueryFunc ColumnTypesQueryFunc, +) (*DBSchema, error) { if pDB == nil { return nil, errors.New("sql.DB cannot be nil") } - return &DBSchema{db: pDB, tableNamesQuery: pTableNamesQuery}, nil + return &DBSchema{ + db: pDB, + tableNamesQuery: pTableNamesQuery, + columnTypesQueryFunc: pColumnTypesQueryFunc, + }, nil } -func (s DBSchema) HasTable(pTableName TableName) bool { +func (s *DBSchema) HasTable(pTableName string) bool { lTables, lError := s.fetchTableNames() if lError != nil { return false } return slices.ContainsFunc(lTables, func(t string) bool { - return strings.EqualFold(t, pTableName.String()) + return strings.EqualFold(t, pTableName) }) } -func (s DBSchema) fetchTableNames() (rTableNames []string, rError error) { +func (s *DBSchema) fetchTableNames() (rTableNames []string, rError error) { lRows, rError := s.db.Query(s.tableNamesQuery) if rError != nil { return @@ -51,18 +61,18 @@ func (s DBSchema) fetchTableNames() (rTableNames []string, rError error) { return } -func (s DBSchema) HasColumn(pTableName TableName, pColumnName ColumnName) bool { +func (s *DBSchema) HasColumn(pTableName, pColumnName string) bool { lColumnTypes, lError := s.fetchColumnTypes(pTableName) if lError != nil { return false } return slices.ContainsFunc(lColumnTypes, func(ct *sql.ColumnType) bool { - return strings.EqualFold(ct.Name(), pColumnName.String()) + return strings.EqualFold(ct.Name(), pColumnName) }) } -func (s DBSchema) fetchColumnTypes(pTableName TableName) (rColumnTypes []*sql.ColumnType, rError error) { - lRows, rError := s.db.Query(fmt.Sprintf("SELECT * FROM %s WHERE 1=0", pTableName.String())) // NOSONAR go:S2077 - Dynamic SQL is controlled and sanitized internally +func (s *DBSchema) fetchColumnTypes(pTableName string) (rColumnTypes []*sql.ColumnType, rError error) { + lRows, rError := s.db.Query(s.columnTypesQueryFunc(pTableName)) if rError != nil { return } diff --git a/schema/schema_test.go b/schema/schema_test.go index fc3bdfe..7985eb9 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -1,6 +1,7 @@ package schema_test import ( + "fmt" "testing" "github.com/ordershift/ormshift" @@ -9,6 +10,10 @@ import ( "github.com/ordershift/ormshift/schema" ) +func testColumnTypesQueryFunc(pTableName string) string { + return fmt.Sprintf("SELECT * FROM %s WHERE 1=0", pTableName) +} + func TestNewDBSchema(t *testing.T) { lDB, lError := ormshift.OpenDatabase(sqlite.Driver(), ormshift.ConnectionParams{InMemory: true}) if !testutils.AssertNotNilResultAndNilError(t, lDB, lError, "ormshift.OpenDatabase") { @@ -16,14 +21,14 @@ func TestNewDBSchema(t *testing.T) { } defer func() { _ = lDB.Close() }() - lDBSchema, lError := schema.NewDBSchema(lDB.DB(), "query") + lDBSchema, lError := schema.NewDBSchema(lDB.DB(), "query", testColumnTypesQueryFunc) if !testutils.AssertNotNilResultAndNilError(t, lDBSchema, lError, "schema.NewDBSchema") { return } } func TestNewDBSchemaFailsWhenDBIsNil(t *testing.T) { - lDBSchema, lError := schema.NewDBSchema(nil, "query") + lDBSchema, lError := schema.NewDBSchema(nil, "query", testColumnTypesQueryFunc) if !testutils.AssertNilResultAndNotNilError(t, lDBSchema, lError, "schema.NewDBSchema") { return } @@ -38,11 +43,8 @@ func TestHasColumn(t *testing.T) { defer func() { _ = lDB.Close() }() lProductAttributeTable := testutils.FakeProductAttributeTable(t) - if lProductAttributeTable == nil { - return - } - _, lError = lDB.SQLExecutor().Exec(sqlite.Driver().SQLBuilder().CreateTable(*lProductAttributeTable)) + _, lError = lDB.SQLExecutor().Exec(sqlite.Driver().SQLBuilder().CreateTable(lProductAttributeTable)) if !testutils.AssertNilError(t, lError, "DB.Exec") { return } @@ -52,17 +54,10 @@ func TestHasColumn(t *testing.T) { for _, lColumn := range lProductAttributeTable.Columns() { testutils.AssertEqualWithLabel(t, true, lDBSchema.HasColumn(lProductAttributeTable.Name(), lColumn.Name()), "DBSchema.HasColumn") } - lAnyTableName, lError := schema.NewTableName("any_table") - if !testutils.AssertNotNilResultAndNilError(t, lAnyTableName, lError, "ormshift.NewTableName") { - return - } - testutils.AssertEqualWithLabel(t, false, lDBSchema.HasTable(*lAnyTableName), "DBSchema.HasTable") - lAnyColumnName, lError := schema.NewColumnName("any_col") - if !testutils.AssertNotNilResultAndNilError(t, lAnyColumnName, lError, "ormshift.NewTableName") { - return - } - testutils.AssertEqualWithLabel(t, false, lDBSchema.HasColumn(lProductAttributeTable.Name(), *lAnyColumnName), "DBSchema.HasColumn") - testutils.AssertEqualWithLabel(t, false, lDBSchema.HasColumn(*lAnyTableName, *lAnyColumnName), "DBSchema.HasColumn") + lAnyTableName := "any_table" + lAnyColumnName := "any_col" + testutils.AssertEqualWithLabel(t, false, lDBSchema.HasColumn(lProductAttributeTable.Name(), lAnyColumnName), "DBSchema.HasColumn") + testutils.AssertEqualWithLabel(t, false, lDBSchema.HasColumn(lAnyTableName, lAnyColumnName), "DBSchema.HasColumn") } func TestHasTableReturnsFalseWhenDatabaseIsInvalid(t *testing.T) { @@ -73,11 +68,8 @@ func TestHasTableReturnsFalseWhenDatabaseIsInvalid(t *testing.T) { defer func() { _ = lDB.Close() }() lProductAttributeTable := testutils.FakeProductAttributeTable(t) - if lProductAttributeTable == nil { - return - } - _, lError = lDB.SQLExecutor().Exec(sqlite.Driver().SQLBuilder().CreateTable(*lProductAttributeTable)) + _, lError = lDB.SQLExecutor().Exec(sqlite.Driver().SQLBuilder().CreateTable(lProductAttributeTable)) if !testutils.AssertNilError(t, lError, "DB.Exec") { return } diff --git a/schema/table.go b/schema/table.go index 3f4eca6..c1a5ce2 100644 --- a/schema/table.go +++ b/schema/table.go @@ -2,63 +2,40 @@ package schema import ( "fmt" - "regexp" "slices" "strings" ) -var regexValidTableName = regexp.MustCompile(`^([A-Za-z_][A-Za-z0-9_]*\.)*[A-Za-z_][A-Za-z0-9_]*$`) - -type TableName struct { - tableName string -} - -func NewTableName(pName string) (*TableName, error) { - if !regexValidTableName.MatchString(pName) { - return nil, fmt.Errorf("invalid table name: %q", pName) - } - return &TableName{pName}, nil -} - -func (tn TableName) String() string { - return tn.tableName -} - type Table struct { - name TableName + name string columns []Column } -func NewTable(pName string) (*Table, error) { - lTableName, lError := NewTableName(pName) - if lError != nil { - return nil, lError - } - return &Table{ - name: *lTableName, +func NewTable(pName string) Table { + return Table{ + name: pName, columns: []Column{}, - }, nil -} - -func (t *Table) AddColumn(pParams NewColumnParams) error { - lColumn, lError := NewColumn(pParams) - if lError != nil { - return lError } - lColumnAlreadyExists := slices.ContainsFunc(t.columns, func(c Column) bool { - return strings.EqualFold(lColumn.Name().String(), c.Name().String()) - }) - if lColumnAlreadyExists { - return fmt.Errorf("column %q already exists in table %q", lColumn.Name().String(), t.name) - } - t.columns = append(t.columns, *lColumn) - return nil } -func (t Table) Name() TableName { +func (t *Table) Name() string { return t.name } -func (t Table) Columns() []Column { +func (t *Table) Columns() []Column { return t.columns } + +func (t *Table) AddColumns(pParams ...NewColumnParams) error { + for _, lColParams := range pParams { + lColumn := NewColumn(lColParams) + lColumnAlreadyExists := slices.ContainsFunc(t.columns, func(c Column) bool { + return strings.EqualFold(lColumn.Name(), c.Name()) + }) + if lColumnAlreadyExists { + return fmt.Errorf("column %q already exists in table %q", lColumn.Name(), t.Name()) + } + t.columns = append(t.columns, lColumn) + } + return nil +} diff --git a/schema/table_test.go b/schema/table_test.go index 7ed323a..028c5bd 100644 --- a/schema/table_test.go +++ b/schema/table_test.go @@ -8,42 +8,14 @@ import ( "github.com/ordershift/ormshift/schema" ) -func TestNewTableFailsWithInvalidName(t *testing.T) { - lInvalidTableName := "123456-table" - lTable, lError := schema.NewTable(lInvalidTableName) - if !testutils.AssertNilResultAndNotNilError(t, lTable, lError, "schema.NewTable") { - return - } - testutils.AssertErrorMessage(t, fmt.Sprintf("invalid table name: %q", lInvalidTableName), lError, "schema.NewTable") -} - -func TestAddColumnFailsWithInvalidName(t *testing.T) { - lProductAttributeTable := testutils.FakeProductAttributeTable(t) - if lProductAttributeTable == nil { - return - } - lInvalidColumnName := "123456-column" - lError := lProductAttributeTable.AddColumn(schema.NewColumnParams{ - Name: lInvalidColumnName, - Type: schema.Integer, - }) - if !testutils.AssertNotNilError(t, lError, "Table.AddColumn") { - return - } - testutils.AssertErrorMessage(t, fmt.Sprintf("invalid column name: %q", lInvalidColumnName), lError, "Table.AddColumn") -} - func TestAddColumnFailsWhenAlreadyExists(t *testing.T) { lProductAttributeTable := testutils.FakeProductAttributeTable(t) - if lProductAttributeTable == nil { - return - } - lError := lProductAttributeTable.AddColumn(schema.NewColumnParams{ + lError := lProductAttributeTable.AddColumns(schema.NewColumnParams{ Name: "value", Type: schema.Integer, }) - if !testutils.AssertNotNilError(t, lError, "Table.AddColumn") { + if !testutils.AssertNotNilError(t, lError, "Table.AddColumns") { return } - testutils.AssertErrorMessage(t, fmt.Sprintf("column %q already exists in table %q", "value", "product_attribute"), lError, "Table.AddColumn") + testutils.AssertErrorMessage(t, fmt.Sprintf("column %q already exists in table %q", "value", "product_attribute"), lError, "Table.AddColumns") }