Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 36 additions & 25 deletions orm.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func reflectStruct(s interface{}, cols []string, row *sql.Rows) error {
}
func reflectStructValue(v reflect.Value, t reflect.Type, cols []string, row *sql.Rows) error {
if v.Kind() != reflect.Ptr {
panic(errors.New("holder should be pointer"))
return errors.New("holder should be pointer")
}
v = v.Elem()
targets := make([]interface{}, len(cols))
Expand Down Expand Up @@ -357,15 +357,15 @@ type orColumn struct {
/**
返回三个值,一个是struct主键的field,一个是主键对应的数据库的值 ,一个是[]*orColumn
*/
func getOrColumns(s interface{}) (string, string, []*orColumn) {
func getOrColumns(s interface{}) (string, string, []*orColumn, error) {
t := reflect.TypeOf(s).Elem()
return getOrColumnsByType(t)
}

/**
根据struct中的值找到其他struct进行relation的联系,并组成[]*orColumn结构,返回三个值,一个是struct主键的field,一个是主键对应的数据库的值 ,一个是[]*orColumn
*/
func getOrColumnsByType(t reflect.Type) (string, string, []*orColumn) {
func getOrColumnsByType(t reflect.Type) (string, string, []*orColumn, error) {
res := make([]*orColumn, 0)
pkCol := ""
pkField := ""
Expand All @@ -378,27 +378,27 @@ func getOrColumnsByType(t reflect.Type) (string, string, []*orColumn) {
var orType reflect.Type
if orTag == "has_one" {
if ft.Type.Kind() != reflect.Ptr {
panic(errors.New(ft.Name + " should be pointer"))
return "", "", res, errors.New(ft.Name + " should be pointer")
}
orType = ft.Type.Elem()
} else if orTag == "has_many" {
if ft.Type.Kind() != reflect.Slice {
panic(errors.New(ft.Name + " should be slice of pointer"))
return "", "", res, errors.New(ft.Name + " should be slice of pointer")
}
elemType := ft.Type.Elem()
if elemType.Kind() != reflect.Ptr {
panic(errors.New(ft.Name + " should be slice of pointer"))
return "", "", res, errors.New(ft.Name + " should be slice of pointer")
}
orType = elemType.Elem()
} else if orTag == "belongs_to" {
if ft.Type.Kind() != reflect.Ptr {
panic(errors.New(ft.Name + " should be pointer"))
return "", "", res, errors.New(ft.Name + " should be pointer")
}
orType = ft.Type.Elem()
}
orTableName := ft.Tag.Get("table")
if orTableName == "" {
panic(errors.New("invalid table name in or tag on field: " + ft.Name))
return "", "", res, errors.New("invalid table name in or tag on field: " + ft.Name)
}
res = append(res, &orColumn{
fieldName: ft.Name,
Expand All @@ -407,7 +407,7 @@ func getOrColumnsByType(t reflect.Type) (string, string, []*orColumn) {
orType: orType,
})
} else {
panic(errors.New("unsupported or tag: " + orTag + ", only support has_one, has_many and belongs_to for now"))
return "", "", res, errors.New("unsupported or tag: " + orTag + ", only support has_one, has_many and belongs_to for now")
}
}
dbTag := ft.Tag.Get("db")
Expand All @@ -426,7 +426,7 @@ func getOrColumnsByType(t reflect.Type) (string, string, []*orColumn) {
pkField = ft.Name
}
}
return pkField, pkCol, res
return pkField, pkCol, res, nil
}

/**
Expand Down Expand Up @@ -463,7 +463,10 @@ func selectOne(c context.Context, tdx Tdx, s interface{}, query string, args ...
if err != nil {
return err
}
pkField, pkCol, orColumns := getOrColumns(s)
pkField, pkCol, orColumns, err := getOrColumns(s)
if err != nil {
return err
}
if orColumns != nil && len(orColumns) > 0 {
v := reflect.ValueOf(s).Elem()
pkValue, err := getFieldValue(s, pkField)
Expand All @@ -486,7 +489,7 @@ func selectOne(c context.Context, tdx Tdx, s interface{}, query string, args ...
} else if orCol.or == "belongs_to" {
fk := getPkColumnByType(orCol.orType)
if fk == "" {
panic(errors.New("error while getting primary key of " + orCol.table + " for belongs_to"))
return errors.New("error while getting primary key of " + orCol.table + " for belongs_to")
}
fkValue, err := getFieldValue(s, colName2FieldName(fk))
if err != nil {
Expand Down Expand Up @@ -809,7 +812,10 @@ func selectManyInternal(c context.Context, tdx Tdx, s interface{}, processOr boo
if isPtr {
t = t.Elem()
if processOr {
pkField, pkCol, orCols = getOrColumnsByType(t)
pkField, pkCol, orCols, err = getOrColumnsByType(t)
if err != nil {
return err
}
if orCols != nil && len(orCols) > 0 {
hasOrCols = true
}
Expand Down Expand Up @@ -844,7 +850,7 @@ func selectManyInternal(c context.Context, tdx Tdx, s interface{}, processOr boo
}
fv := v.Elem().FieldByName(fName)
if !fv.CanAddr() {
logrus.WithField("sql", queryStr).Errorf("missing field: %s", fName)
logrus.WithField("sql", queryStr).Warnf("missing field: %s", fName)
var b interface{}
targets[k] = &b
} else {
Expand Down Expand Up @@ -983,7 +989,7 @@ var zeroTime = time.Unix(1, 0)
func columnsByStructFields(s interface{}, cols []string) ([]interface{}, reflect.Value, bool, string) {
t := reflect.TypeOf(s).Elem()
v := reflect.ValueOf(s).Elem()
ret := make([]interface{}, 0, len(cols))
ret := make([]interface{}, len(cols))
var pk reflect.Value
var pkName string
isAi := false
Expand All @@ -1006,19 +1012,24 @@ func columnsByStructFields(s interface{}, cols []string) ([]interface{}, reflect
if ft.Tag.Get("ai") == "true" || isPkOrAi(dbTag, "ai") {
isAi = true
}
break
}
}
//通过cols获取struct中的值
for _, value := range cols {
value = colName2FieldName(value)
r := v.FieldByName(value).Addr().Interface()
if v.FieldByName(value).Type().String() == "time.Time" {
if r.(*time.Time).IsZero() {
r = &zeroTime

//通过db 标签和cols 做对比 取出struct中的值
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

主要修改的这个地方

realCol := ""
if dbCol != "" {
realCol = dbCol
} else {
realCol = t.Field(k).Name
}
if kCol := IsContain(realCol, cols); kCol != -1 {
r := v.Field(k).Addr().Interface()
if v.Field(k).Type().String() == "time.Time" {
if r.(*time.Time).IsZero() {
r = &zeroTime
}
}
ret[kCol] = r
}
ret = append(ret, r)
}
return ret, pk, isAi, pkName
}
Expand Down
2 changes: 1 addition & 1 deletion orm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
type TestOrmA123 struct {
TestID int64 `json:"test_id" pk:"true" ai:"true" db:"test_id,ai,pk"`
OtherId int64
Description string
Description string `json:"description" db:"description"`
Name sql.NullString
StartDate time.Time
EndDate time.Time
Expand Down
8 changes: 8 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ import (
"github.com/sirupsen/logrus"
)

func IsContain(str string,arr []string) int {
for k,v:=range arr {
if v == str {
return k
}
}
return -1
}
//替换query 中??为长度为len的?
func getNumInStr(len int, query string) string {
str := ""
Expand Down