Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,12 @@ func (m M0001CreateUserTable) Up(migrator *migrations.Migrator) error {

func (m M0001CreateUserTable) Down(migrator *migrations.Migrator) error {
db := migrator.Database()
tableName := "user"
if !db.DBSchema().HasTable(tableName) {
table := "user"
if !db.DBSchema().HasTable(table) {
// if the table already doesn't exist, nothing to do here
return nil
}
_, err := db.SQLExecutor().Exec(db.SQLBuilder().DropTable(tableName))
_, err := db.SQLExecutor().Exec(db.SQLBuilder().DropTable(table))
return err
}
```
Expand All @@ -245,28 +245,28 @@ type M0002AddUpdatedAtColumn struct{}

func (m M0002AddUpdatedAtColumn) Up(migrator *migrations.Migrator) error {
db := migrator.Database()
tableName := "user"
table := "user"
col := schema.NewColumn(schema.NewColumnParams{Name: "updated_at", Type: schema.DateTime})

if db.DBSchema().HasColumn(tableName, col.Name()) {
if db.DBSchema().HasColumn(table, col.Name()) {
// if the column already exists, nothing to do here
return nil
}

_, err := db.SQLExecutor().Exec(db.SQLBuilder().AlterTableAddColumn(tableName, col))
_, err := db.SQLExecutor().Exec(db.SQLBuilder().AlterTableAddColumn(table, col))
return err
}

func (m M0002AddUpdatedAtColumn) Down(migrator *migrations.Migrator) error {
db := migrator.Database()
tableName := "user"
table := "user"
colName := "updated_at"

if !db.DBSchema().HasColumn(tableName, colName) {
if !db.DBSchema().HasColumn(table, colName) {
// if the column already doesn't exist, nothing to do here
return nil
}
_, err := db.SQLExecutor().Exec(db.SQLBuilder().AlterTableDropColumn(tableName, colName))
_, err := db.SQLExecutor().Exec(db.SQLBuilder().AlterTableDropColumn(table, colName))
return err
}
```
Expand Down
68 changes: 34 additions & 34 deletions builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,89 +7,89 @@ import (
"github.com/ordershift/ormshift/schema"
)

// DDSQLBuilder creates DDL (Data Definition Language) SQL commands for defining schema in DBMS.
// DDLSQLBuilder creates DDL (Data Definition Language) SQL commands for defining schema in DBMS.
type DDLSQLBuilder interface {
CreateTable(pTable schema.Table) string
DropTable(pTableName string) string
AlterTableAddColumn(pTableName string, pColumn schema.Column) string
AlterTableDropColumn(pTableName, pColumnName string) string
ColumnTypeAsString(pColumnType schema.ColumnType) string
CreateTable(table schema.Table) string
DropTable(table string) string
AlterTableAddColumn(table string, column schema.Column) string
AlterTableDropColumn(table, column string) string
ColumnTypeAsString(columnType schema.ColumnType) string
}

// ColumnsValues represents a mapping between column names and their corresponding values.
type ColumnsValues map[string]any

// ToNamedArgs transforms ColumnsValues to a sql.NamedArg array ordered by name, e.g.:
//
// lColumnsValues := ColumnsValues{"id": 5, "sku": "ZTX-9000", "is_simple": true}
// lNamedArgs := lColumnsValues.ToNamedArgs()
// //lNamedArgs == []sql.NamedArg{{Name: "id", Value: 5},{Name: "is_simple", Value: true},{Name: "sku", Value: "ZTX-9000"}}
// values := ColumnsValues{"id": 5, "sku": "ZTX-9000", "is_simple": true}
// args := values.ToNamedArgs()
// //args == []sql.NamedArg{{Name: "id", Value: 5},{Name: "is_simple", Value: true},{Name: "sku", Value: "ZTX-9000"}}
func (cv *ColumnsValues) ToNamedArgs() []sql.NamedArg {
lNamedArgs := []sql.NamedArg{}
args := []sql.NamedArg{}
for c, v := range *cv {
lNamedArgs = append(lNamedArgs, sql.Named(c, v))
args = append(args, sql.Named(c, v))
}
slices.SortFunc(lNamedArgs, func(a, b sql.NamedArg) int {
slices.SortFunc(args, func(a, b sql.NamedArg) int {
if a.Name < b.Name {
return -1
}
return 1
})
return lNamedArgs
return args
}

// ToColumns returns the column names from ColumnsValues as a string array ordered by name, e.g.:
func (cv *ColumnsValues) ToColumns() []string {
lColumns := []string{}
columns := []string{}
for c := range *cv {
lColumns = append(lColumns, c)
columns = append(columns, c)
}
slices.Sort(lColumns)
return lColumns
slices.Sort(columns)
return columns
}

// DMLSQLBuilder creates DML (Data Manipulation Language) SQL commands for manipulating data in DBMS.
type DMLSQLBuilder interface {
Insert(pTableName string, pColumns []string) string
InsertWithValues(pTableName string, pColumnsValues ColumnsValues) (string, []any)
Update(pTableName string, pColumns, pColumnsWhere []string) string
UpdateWithValues(pTableName string, pColumns, pColumnsWhere []string, pValues ColumnsValues) (string, []any)
Delete(pTableName string, pColumnsWhere []string) string
DeleteWithValues(pTableName string, pWhereColumnsValues ColumnsValues) (string, []any)
Select(pTableName string, pColumns, pColumnsWhere []string) string
SelectWithValues(pTableName string, pColumns []string, pWhereColumnsValues ColumnsValues) (string, []any)
SelectWithPagination(pSQLSelectCommand string, pRowsPerPage, pPageNumber uint) string
Insert(table string, columns []string) string
InsertWithValues(table string, values ColumnsValues) (string, []any)
Update(table string, columns, where []string) string
UpdateWithValues(table string, columns, where []string, values ColumnsValues) (string, []any)
Delete(table string, where []string) string
DeleteWithValues(table string, where ColumnsValues) (string, []any)
Select(table string, columns, where []string) string
SelectWithValues(table string, columns []string, where ColumnsValues) (string, []any)
SelectWithPagination(sql string, size, number uint) string

// InteroperateSQLCommandWithNamedArgs acts as a SQL command translator that standardizes SQL commands according to the database driver being used e.g.,
//
// pSQLCommand := "select * from user where id = @id"
// pNamedArg := sql.Named("id", 123)
// sql := "select * from user where id = @id"
// namedArg := sql.Named("id", 123)
//
// PostgreSQL:
// q, p := sqlbuilder.InteroperateSQLCommandWithNamedArgs(pSQLCommand, pNamedArg)
// q, p := sqlbuilder.InteroperateSQLCommandWithNamedArgs(sql, namedArg)
// //q == "select * from user where id = $1"
// //p == 123
//
// SQLite:
// q, p = sqlbuilder.InteroperateSQLCommandWithNamedArgs(pSQLCommand, pNamedArg)
// q, p = sqlbuilder.InteroperateSQLCommandWithNamedArgs(sql, namedArg)
// //q == "select * from user where id = @id"
// //p == sql.Named("id", 123)
//
// SQL Server:
// q, p = sqlbuilder.InteroperateSQLCommandWithNamedArgs(pSQLCommand, pNamedArg)
// q, p = sqlbuilder.InteroperateSQLCommandWithNamedArgs(sql, namedArg)
// //q == "select * from user where id = @id"
// //p == sql.Named("id", 123)
//
// MySQL (not yet supported, expects question marks in parameters):
//
// q, p = sqlbuilder.InteroperateSQLCommandWithNamedArgs(pSQLCommand, pNamedArg)
// q, p = sqlbuilder.InteroperateSQLCommandWithNamedArgs(sql, namedArg)
// //q == "select * from user where id = ?"
// //p == 123
InteroperateSQLCommandWithNamedArgs(pSQLCommand string, pNamedArgs ...sql.NamedArg) (string, []any)
InteroperateSQLCommandWithNamedArgs(sql string, args ...sql.NamedArg) (string, []any)
}

type SQLBuilder interface {
DDLSQLBuilder
DMLSQLBuilder
QuoteIdentifier(pIdentifier string) string
QuoteIdentifier(identifier string) string
}
28 changes: 14 additions & 14 deletions builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,21 @@ import (
)

func TestColumnsValuesToNamedArgs(t *testing.T) {
lColumnsValues := ormshift.ColumnsValues{"id": 1, "sku": "ABC1234", "active": true}
lNamedArgs := lColumnsValues.ToNamedArgs()
testutils.AssertEqualWithLabel(t, 3, len(lNamedArgs), "ColumnsValues.ToNamedArgs")
testutils.AssertEqualWithLabel(t, lNamedArgs[0].Name, "active", "ColumnsValues.ToNamedArgs[0].Name")
testutils.AssertEqualWithLabel(t, lNamedArgs[0].Value, true, "ColumnsValues.ToNamedArgs[0].Value")
testutils.AssertEqualWithLabel(t, lNamedArgs[1].Name, "id", "ColumnsValues.ToNamedArgs[1].Name")
testutils.AssertEqualWithLabel(t, lNamedArgs[1].Value, 1, "ColumnsValues.ToNamedArgs[1].Value")
testutils.AssertEqualWithLabel(t, lNamedArgs[2].Name, "sku", "ColumnsValues.ToNamedArgs[2].Name")
testutils.AssertEqualWithLabel(t, lNamedArgs[2].Value, "ABC1234", "ColumnsValues.ToNamedArgs[2].Value")
values := ormshift.ColumnsValues{"id": 1, "sku": "ABC1234", "active": true}
args := values.ToNamedArgs()
testutils.AssertEqualWithLabel(t, 3, len(args), "ColumnsValues.ToNamedArgs")
testutils.AssertEqualWithLabel(t, args[0].Name, "active", "ColumnsValues.ToNamedArgs[0].Name")
testutils.AssertEqualWithLabel(t, args[0].Value, true, "ColumnsValues.ToNamedArgs[0].Value")
testutils.AssertEqualWithLabel(t, args[1].Name, "id", "ColumnsValues.ToNamedArgs[1].Name")
testutils.AssertEqualWithLabel(t, args[1].Value, 1, "ColumnsValues.ToNamedArgs[1].Value")
testutils.AssertEqualWithLabel(t, args[2].Name, "sku", "ColumnsValues.ToNamedArgs[2].Name")
testutils.AssertEqualWithLabel(t, args[2].Value, "ABC1234", "ColumnsValues.ToNamedArgs[2].Value")
}

func TestColumnsValuesToColumns(t *testing.T) {
lColumnsValues := ormshift.ColumnsValues{"id": 1, "sku": "ABC1234"}
lColumns := lColumnsValues.ToColumns()
testutils.AssertEqualWithLabel(t, 2, len(lColumns), "ColumnsValues.ToColumns")
testutils.AssertEqualWithLabel(t, lColumns[0], "id", "ColumnsValues.ToColumns[0]")
testutils.AssertEqualWithLabel(t, lColumns[1], "sku", "ColumnsValues.ToColumns[1]")
values := ormshift.ColumnsValues{"id": 1, "sku": "ABC1234"}
columns := values.ToColumns()
testutils.AssertEqualWithLabel(t, 2, len(columns), "ColumnsValues.ToColumns")
testutils.AssertEqualWithLabel(t, columns[0], "id", "ColumnsValues.ToColumns[0]")
testutils.AssertEqualWithLabel(t, columns[1], "sku", "ColumnsValues.ToColumns[1]")
}
32 changes: 16 additions & 16 deletions database.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ type ConnectionParams struct {

type DatabaseDriver interface {
Name() string
ConnectionString(pParams ConnectionParams) string
ConnectionString(params ConnectionParams) string
SQLBuilder() SQLBuilder
DBSchema(pDB *sql.DB) (*schema.DBSchema, error)
DBSchema(db *sql.DB) (*schema.DBSchema, error)
}

type Database struct {
Expand All @@ -33,26 +33,26 @@ type Database struct {
dbSchema *schema.DBSchema
}

func OpenDatabase(pDriver DatabaseDriver, pParams ConnectionParams) (*Database, error) {
if pDriver == nil {
func OpenDatabase(driver DatabaseDriver, params ConnectionParams) (*Database, error) {
if driver == nil {
return nil, errors.New("DatabaseDriver cannot be nil")
}
lConnectionString := pDriver.ConnectionString(pParams)
lDB, lError := sql.Open(pDriver.Name(), lConnectionString)
if lError != nil {
return nil, fmt.Errorf("sql.Open failed: %w", lError)
connectionString := driver.ConnectionString(params)
db, err := sql.Open(driver.Name(), connectionString)
if err != nil {
return nil, fmt.Errorf("sql.Open failed: %w", err)
}
lDBSchema, lError := pDriver.DBSchema(lDB)
if lError != nil {
return nil, fmt.Errorf("failed to get DB schema: %w", lError)
dbSchema, err := driver.DBSchema(db)
if err != nil {
return nil, fmt.Errorf("failed to get DB schema: %w", err)
}

return &Database{
driver: pDriver,
db: lDB,
connectionString: lConnectionString,
sqlBuilder: pDriver.SQLBuilder(),
dbSchema: lDBSchema,
driver: driver,
db: db,
connectionString: connectionString,
sqlBuilder: driver.SQLBuilder(),
dbSchema: dbSchema,
}, nil
}

Expand Down
Loading