1
0
mirror of https://github.com/gogf/gf.git synced 2025-04-05 11:18:50 +08:00

improve struct embedded association case of with feature for package gdb

This commit is contained in:
John Guo 2021-05-11 20:00:50 +08:00
parent 6a80091fef
commit 2e38416e12
5 changed files with 134 additions and 12 deletions

View File

@ -55,7 +55,7 @@ func (m *Model) WithAll() *Model {
// getWithTagObjectArrayFrom retrieves and returns object array that have "with" tag in the struct.
func (m *Model) getWithTagObjectArrayFrom(pointer interface{}) ([]interface{}, error) {
fieldMap, err := structs.FieldMap(pointer, nil)
fieldMap, err := structs.FieldMap(pointer, nil, false)
if err != nil {
return nil, err
}
@ -86,6 +86,7 @@ func (m *Model) doWithScanStruct(pointer interface{}) error {
err error
withArray = m.withArray
)
// If with all feature is enabled, it then retrieves all the attributes which have with tag defined.
if m.withAll {
withArray, err = m.getWithTagObjectArrayFrom(pointer)
if err != nil {
@ -95,10 +96,11 @@ func (m *Model) doWithScanStruct(pointer interface{}) error {
if len(withArray) == 0 {
return nil
}
fieldMap, err := structs.FieldMap(pointer, nil)
fieldMap, err := structs.FieldMap(pointer, nil, false)
if err != nil {
return err
}
// Check the with array and automatically call the ScanList to complete association querying.
for withIndex, withItem := range withArray {
withItemReflectValueType, err := structs.StructType(withItem)
if err != nil {
@ -110,6 +112,7 @@ func (m *Model) doWithScanStruct(pointer interface{}) error {
fieldType = fieldValue.Type()
fieldTypeStr = gstr.TrimAll(fieldType.String(), "*[]")
)
// It does select operation if the field type is in the specified with type array.
if gstr.Compare(fieldTypeStr, withItemReflectValueTypeStr) == 0 {
var (
withTag string
@ -174,6 +177,7 @@ func (m *Model) doWithScanStruct(pointer interface{}) error {
}
// doWithScanStructs handles model association operations feature for struct slice.
// Also see doWithScanStruct.
func (m *Model) doWithScanStructs(pointer interface{}) error {
var (
err error
@ -188,7 +192,7 @@ func (m *Model) doWithScanStructs(pointer interface{}) error {
if len(withArray) == 0 {
return nil
}
fieldMap, err := structs.FieldMap(pointer, nil)
fieldMap, err := structs.FieldMap(pointer, nil, false)
if err != nil {
return err
}

View File

@ -212,7 +212,7 @@ PRIMARY KEY (id)
})
}
func Test_Table_Relation_With_ScanList(t *testing.T) {
func Test_Table_Relation_With(t *testing.T) {
var (
tableUser = "user"
tableUserDetail = "user_detail"
@ -411,7 +411,7 @@ PRIMARY KEY (id)
})
}
func Test_Table_Relation_WithAll_Scan(t *testing.T) {
func Test_Table_Relation_WithAll(t *testing.T) {
var (
tableUser = "user"
tableUserDetail = "user_detail"
@ -526,7 +526,7 @@ PRIMARY KEY (id)
})
}
func Test_Table_Relation_WithAll_ScanList(t *testing.T) {
func Test_Table_Relation_WithAll_List(t *testing.T) {
var (
tableUser = "user"
tableUserDetail = "user_detail"
@ -666,3 +666,118 @@ PRIMARY KEY (id)
t.Assert(users[1].UserScores[4].Score, 5)
})
}
func Test_Table_Relation_WithAll_Embedded(t *testing.T) {
var (
tableUser = "user"
tableUserDetail = "user_detail"
tableUserScores = "user_scores"
)
if _, err := db.Exec(fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
id int(10) unsigned NOT NULL AUTO_INCREMENT,
name varchar(45) NOT NULL,
PRIMARY KEY (id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
`, tableUser)); err != nil {
gtest.Error(err)
}
defer dropTable(tableUser)
if _, err := db.Exec(fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
uid int(10) unsigned NOT NULL AUTO_INCREMENT,
address varchar(45) NOT NULL,
PRIMARY KEY (uid)
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
`, tableUserDetail)); err != nil {
gtest.Error(err)
}
defer dropTable(tableUserDetail)
if _, err := db.Exec(fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
id int(10) unsigned NOT NULL AUTO_INCREMENT,
uid int(10) unsigned NOT NULL,
score int(10) unsigned NOT NULL,
PRIMARY KEY (id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
`, tableUserScores)); err != nil {
gtest.Error(err)
}
defer dropTable(tableUserScores)
type UserDetail struct {
gmeta.Meta `orm:"table:user_detail"`
Uid int `json:"uid"`
Address string `json:"address"`
}
type UserScores struct {
gmeta.Meta `orm:"table:user_scores"`
Id int `json:"id"`
Uid int `json:"uid"`
Score int `json:"score"`
}
type User struct {
gmeta.Meta `orm:"table:user"`
*UserDetail `orm:"with:uid=id"`
Id int `json:"id"`
Name string `json:"name"`
UserScores []*UserScores `orm:"with:uid=id"`
}
// Initialize the data.
var err error
for i := 1; i <= 5; i++ {
// User.
_, err = db.Insert(tableUser, g.Map{
"id": i,
"name": fmt.Sprintf(`name_%d`, i),
})
gtest.Assert(err, nil)
// Detail.
_, err = db.Insert(tableUserDetail, g.Map{
"uid": i,
"address": fmt.Sprintf(`address_%d`, i),
})
gtest.Assert(err, nil)
// Scores.
for j := 1; j <= 5; j++ {
_, err = db.Insert(tableUserScores, g.Map{
"uid": i,
"score": j,
})
gtest.Assert(err, nil)
}
}
gtest.C(t, func(t *gtest.T) {
var user *User
err := db.Model(tableUser).WithAll().Where("id", 3).Scan(&user)
t.AssertNil(err)
t.Assert(user.Id, 3)
t.AssertNE(user.UserDetail, nil)
t.Assert(user.UserDetail.Uid, 3)
t.Assert(user.UserDetail.Address, `address_3`)
t.Assert(len(user.UserScores), 5)
t.Assert(user.UserScores[0].Uid, 3)
t.Assert(user.UserScores[0].Score, 1)
t.Assert(user.UserScores[4].Uid, 3)
t.Assert(user.UserScores[4].Score, 5)
})
gtest.C(t, func(t *gtest.T) {
var user User
err := db.Model(tableUser).WithAll().Where("id", 4).Scan(&user)
t.AssertNil(err)
t.Assert(user.Id, 4)
t.AssertNE(user.UserDetail, nil)
t.Assert(user.UserDetail.Uid, 4)
t.Assert(user.UserDetail.Address, `address_4`)
t.Assert(len(user.UserScores), 5)
t.Assert(user.UserScores[0].Uid, 4)
t.Assert(user.UserScores[0].Score, 1)
t.Assert(user.UserScores[4].Uid, 4)
t.Assert(user.UserScores[4].Score, 5)
})
}

View File

@ -61,8 +61,11 @@ func (f *Field) OriginalKind() reflect.Kind {
// The parameter `priority` specifies the priority tag array for retrieving from high to low.
// If it's given `nil`, it returns map[name]*Field, of which the `name` is attribute name.
//
// The parameter `recursive` specifies the whether retrieving the fields recursively if the attribute
// is an embedded struct.
//
// Note that it only retrieves the exported attributes with first letter up-case from struct.
func FieldMap(pointer interface{}, priority []string) (map[string]*Field, error) {
func FieldMap(pointer interface{}, priority []string, recursive bool) (map[string]*Field, error) {
fields, err := getFieldValues(pointer)
if err != nil {
return nil, err
@ -88,8 +91,8 @@ func FieldMap(pointer interface{}, priority []string) (map[string]*Field, error)
if tagValue != "" {
mapField[tagValue] = tempField
} else {
if field.IsEmbedded() {
m, err := FieldMap(field.Value, priority)
if recursive && field.IsEmbedded() {
m, err := FieldMap(field.Value, priority, recursive)
if err != nil {
return nil, err
}

View File

@ -110,7 +110,7 @@ func Test_FieldMap(t *testing.T) {
Pass string `my-tag1:"pass1" my-tag2:"pass2" params:"pass"`
}
var user *User
m, _ := structs.FieldMap(user, []string{"params"})
m, _ := structs.FieldMap(user, []string{"params"}, true)
t.Assert(len(m), 3)
_, ok := m["Id"]
t.Assert(ok, true)
@ -130,7 +130,7 @@ func Test_FieldMap(t *testing.T) {
Pass string `my-tag1:"pass1" my-tag2:"pass2" params:"pass"`
}
var user *User
m, _ := structs.FieldMap(user, nil)
m, _ := structs.FieldMap(user, nil, true)
t.Assert(len(m), 3)
_, ok := m["Id"]
t.Assert(ok, true)

View File

@ -28,7 +28,7 @@ func (v *Validator) CheckStruct(object interface{}, rules interface{}, messages
var (
errorMaps = make(ErrorMap) // Returned error.
)
mapField, err := structs.FieldMap(object, aliasNameTagPriority)
mapField, err := structs.FieldMap(object, aliasNameTagPriority, true)
if err != nil {
return newErrorStr("invalid_object", err.Error())
}