1
0
mirror of https://github.com/gogf/gf.git synced 2025-04-05 11:18:50 +08:00

fix(database/gdb): issue where the Count/Value/Array query logic was incompatible with the old version when users extended the returned result fields using the Select Hook (#3995)

This commit is contained in:
John Guo 2024-12-01 23:47:51 +08:00 committed by GitHub
parent 42eae41599
commit 2c916f8222
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 132 additions and 55 deletions

View File

@ -13,6 +13,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime" "github.com/gogf/gf/v2/os/gtime"
@ -1283,12 +1284,12 @@ func Test_Issue3754(t *testing.T) {
func Test_Issue3626(t *testing.T) { func Test_Issue3626(t *testing.T) {
table := "issue3626" table := "issue3626"
array := gstr.SplitAndTrim(gtest.DataContent(`issue3626.sql`), ";") array := gstr.SplitAndTrim(gtest.DataContent(`issue3626.sql`), ";")
defer dropTable(table)
for _, v := range array { for _, v := range array {
if _, err := db.Exec(ctx, v); err != nil { if _, err := db.Exec(ctx, v); err != nil {
gtest.Error(err) gtest.Error(err)
} }
} }
defer dropTable(table)
// Insert. // Insert.
gtest.C(t, func(t *gtest.T) { gtest.C(t, func(t *gtest.T) {
@ -1377,3 +1378,34 @@ func Test_Issue3932(t *testing.T) {
t.Assert(one["id"], 10) t.Assert(one["id"], 10)
}) })
} }
// https://github.com/gogf/gf/issues/3968
func Test_Issue3968(t *testing.T) {
table := createInitTable()
defer dropTable(table)
gtest.C(t, func(t *gtest.T) {
var hook = gdb.HookHandler{
Select: func(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) {
result, err = in.Next(ctx)
if err != nil {
return nil, err
}
if result != nil {
for i, _ := range result {
result[i]["location"] = gvar.New("ny")
}
}
return
},
}
var (
count int
result gdb.Result
)
err := db.Model(table).Hook(hook).ScanAndCount(&result, &count, false)
t.AssertNil(err)
t.Assert(count, 10)
t.Assert(len(result), 10)
})
}

View File

@ -396,12 +396,13 @@ const (
linkPattern = `(\w+):([\w\-\$]*):(.*?)@(\w+?)\((.+?)\)/{0,1}([^\?]*)\?{0,1}(.*)` linkPattern = `(\w+):([\w\-\$]*):(.*?)@(\w+?)\((.+?)\)/{0,1}([^\?]*)\?{0,1}(.*)`
) )
type queryType int type SelectType int
const ( const (
queryTypeNormal queryType = iota SelectTypeDefault SelectType = iota
queryTypeCount SelectTypeCount
queryTypeValue SelectTypeValue
SelectTypeArray
) )
type joinOperator string type joinOperator string
@ -700,13 +701,13 @@ func getConfigNodeByWeight(cg ConfigGroup) *ConfigNode {
} }
// Exclude the right border value. // Exclude the right border value.
var ( var (
min = 0 minWeight = 0
max = 0 maxWeight = 0
random = grand.N(0, total-1) random = grand.N(0, total-1)
) )
for i := 0; i < len(cg); i++ { for i := 0; i < len(cg); i++ {
max = min + cg[i].Weight*100 maxWeight = minWeight + cg[i].Weight*100
if random >= min && random < max { if random >= minWeight && random < maxWeight {
// ==================================================== // ====================================================
// Return a COPY of the ConfigNode. // Return a COPY of the ConfigNode.
// ==================================================== // ====================================================
@ -714,7 +715,7 @@ func getConfigNodeByWeight(cg ConfigGroup) *ConfigNode {
node = cg[i] node = cg[i]
return &node return &node
} }
min = max minWeight = maxWeight
} }
return nil return nil
} }

View File

@ -278,7 +278,7 @@ func (c *Core) doUnion(ctx context.Context, unionType int, unions ...*Model) *Mo
unionTypeStr = "UNION" unionTypeStr = "UNION"
} }
for _, v := range unions { for _, v := range unions {
sqlWithHolder, holderArgs := v.getFormattedSqlAndArgs(ctx, queryTypeNormal, false) sqlWithHolder, holderArgs := v.getFormattedSqlAndArgs(ctx, SelectTypeDefault, false)
if composedSqlStr == "" { if composedSqlStr == "" {
composedSqlStr += fmt.Sprintf(`(%s)`, sqlWithHolder) composedSqlStr += fmt.Sprintf(`(%s)`, sqlWithHolder)
} else { } else {

View File

@ -23,8 +23,6 @@ type internalCtxData struct {
} }
// column stores column data in ctx for internal usage purpose. // column stores column data in ctx for internal usage purpose.
// Deprecated.
// TODO remove this usage in future.
type internalColumnData struct { type internalColumnData struct {
// The first column in result response from database server. // The first column in result response from database server.
// This attribute is used for Value/Count selection statement purpose, // This attribute is used for Value/Count selection statement purpose,

View File

@ -90,7 +90,7 @@ func (m *Model) getSelectResultFromCache(ctx context.Context, sql string, args .
} }
func (m *Model) saveSelectResultToCache( func (m *Model) saveSelectResultToCache(
ctx context.Context, queryType queryType, result Result, sql string, args ...interface{}, ctx context.Context, selectType SelectType, result Result, sql string, args ...interface{},
) (err error) { ) (err error) {
if !m.cacheEnabled || m.tx != nil { if !m.cacheEnabled || m.tx != nil {
return return
@ -108,18 +108,19 @@ func (m *Model) saveSelectResultToCache(
// Special handler for Value/Count operations result. // Special handler for Value/Count operations result.
if len(result) > 0 { if len(result) > 0 {
var core = m.db.GetCore() var core = m.db.GetCore()
switch queryType { switch selectType {
case queryTypeValue, queryTypeCount: case SelectTypeValue, SelectTypeArray, SelectTypeCount:
if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil { if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil {
if result[0][internalData.FirstResultColumn].IsEmpty() { if result[0][internalData.FirstResultColumn].IsEmpty() {
result = nil result = nil
} }
} }
default:
} }
} }
// In case of Cache Penetration. // In case of Cache Penetration.
if result.IsEmpty() { if result != nil && result.IsEmpty() {
if m.cacheOption.Force { if m.cacheOption.Force {
result = Result{} result = Result{}
} else { } else {

View File

@ -66,11 +66,12 @@ type internalParamHookDelete struct {
// which is usually not be interesting for upper business hook handler. // which is usually not be interesting for upper business hook handler.
type HookSelectInput struct { type HookSelectInput struct {
internalParamHookSelect internalParamHookSelect
Model *Model // Current operation Model. Model *Model // Current operation Model.
Table string // The table name that to be used. Update this attribute to change target table name. Table string // The table name that to be used. Update this attribute to change target table name.
Schema string // The schema name that to be used. Update this attribute to change target schema name. Schema string // The schema name that to be used. Update this attribute to change target schema name.
Sql string // The sql string that to be committed. Sql string // The sql string that to be committed.
Args []interface{} // The arguments of sql. Args []interface{} // The arguments of sql.
SelectType SelectType // The type of this SELECT operation.
} }
// HookInsertInput holds the parameters for insert hook operation. // HookInsertInput holds the parameters for insert hook operation.

View File

@ -28,7 +28,7 @@ import (
// see Model.Where. // see Model.Where.
func (m *Model) All(where ...interface{}) (Result, error) { func (m *Model) All(where ...interface{}) (Result, error) {
var ctx = m.GetCtx() var ctx = m.GetCtx()
return m.doGetAll(ctx, false, where...) return m.doGetAll(ctx, SelectTypeDefault, false, where...)
} }
// AllAndCount retrieves all records and the total count of records from the model. // AllAndCount retrieves all records and the total count of records from the model.
@ -69,7 +69,7 @@ func (m *Model) AllAndCount(useFieldForCount bool) (result Result, totalCount in
} }
// Retrieve all records // Retrieve all records
result, err = m.doGetAll(m.GetCtx(), false) result, err = m.doGetAll(m.GetCtx(), SelectTypeDefault, false)
return return
} }
@ -110,7 +110,7 @@ func (m *Model) One(where ...interface{}) (Record, error) {
if len(where) > 0 { if len(where) > 0 {
return m.Where(where[0], where[1:]...).One() return m.Where(where[0], where[1:]...).One()
} }
all, err := m.doGetAll(ctx, true) all, err := m.doGetAll(ctx, SelectTypeDefault, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -136,24 +136,41 @@ func (m *Model) Array(fieldsAndWhere ...interface{}) ([]Value, error) {
return m.Fields(gconv.String(fieldsAndWhere[0])).Array() return m.Fields(gconv.String(fieldsAndWhere[0])).Array()
} }
} }
all, err := m.All()
var (
field string
core = m.db.GetCore()
ctx = core.injectInternalColumn(m.GetCtx())
)
all, err := m.doGetAll(ctx, SelectTypeArray, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var field string
if len(all) > 0 { if len(all) > 0 {
var recordFields = m.getRecordFields(all[0]) internalData := core.getInternalColumnFromCtx(ctx)
if len(recordFields) > 1 { if internalData == nil {
// it returns error if there are multiple fields in the result record. return nil, gerror.NewCode(
return nil, gerror.NewCodef( gcode.CodeInternalError,
gcode.CodeInvalidParameter, `query count error: the internal context data is missing. there's internal issue should be fixed`,
`invalid fields for "Array" operation, result fields number "%d"%s, but expect one`,
len(recordFields),
gjson.MustEncodeString(recordFields),
) )
} }
if len(recordFields) == 1 { // If FirstResultColumn present, it returns the value of the first record of the first field.
field = recordFields[0] // It means it use no cache mechanism, while cache mechanism makes `internalData` missing.
field = internalData.FirstResultColumn
if field == "" {
// Fields number check.
var recordFields = m.getRecordFields(all[0])
if len(recordFields) == 1 {
field = recordFields[0]
} else {
// it returns error if there are multiple fields in the result record.
return nil, gerror.NewCodef(
gcode.CodeInvalidParameter,
`invalid fields for "Array" operation, result fields number "%d"%s, but expect one`,
len(recordFields),
gjson.MustEncodeString(recordFields),
)
}
} }
} }
return all.Array(field), nil return all.Array(field), nil
@ -398,13 +415,26 @@ func (m *Model) Value(fieldsAndWhere ...interface{}) (Value, error) {
} }
} }
var ( var (
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, queryTypeValue, true) sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, SelectTypeValue, true)
all, err = m.doGetAllBySql(ctx, queryTypeValue, sqlWithHolder, holderArgs...) all, err = m.doGetAllBySql(ctx, SelectTypeValue, sqlWithHolder, holderArgs...)
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(all) > 0 { if len(all) > 0 {
internalData := core.getInternalColumnFromCtx(ctx)
if internalData == nil {
return nil, gerror.NewCode(
gcode.CodeInternalError,
`query count error: the internal context data is missing. there's internal issue should be fixed`,
)
}
// If FirstResultColumn present, it returns the value of the first record of the first field.
// It means it use no cache mechanism, while cache mechanism makes `internalData` missing.
if v, ok := all[0][internalData.FirstResultColumn]; ok {
return v, nil
}
// Fields number check.
var recordFields = m.getRecordFields(all[0]) var recordFields = m.getRecordFields(all[0])
if len(recordFields) == 1 { if len(recordFields) == 1 {
for _, v := range all[0] { for _, v := range all[0] {
@ -445,13 +475,26 @@ func (m *Model) Count(where ...interface{}) (int, error) {
return m.Where(where[0], where[1:]...).Count() return m.Where(where[0], where[1:]...).Count()
} }
var ( var (
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, queryTypeCount, false) sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, SelectTypeCount, false)
all, err = m.doGetAllBySql(ctx, queryTypeCount, sqlWithHolder, holderArgs...) all, err = m.doGetAllBySql(ctx, SelectTypeCount, sqlWithHolder, holderArgs...)
) )
if err != nil { if err != nil {
return 0, err return 0, err
} }
if len(all) > 0 { if len(all) > 0 {
internalData := core.getInternalColumnFromCtx(ctx)
if internalData == nil {
return 0, gerror.NewCode(
gcode.CodeInternalError,
`query count error: the internal context data is missing. there's internal issue should be fixed`,
)
}
// If FirstResultColumn present, it returns the value of the first record of the first field.
// It means it use no cache mechanism, while cache mechanism makes `internalData` missing.
if v, ok := all[0][internalData.FirstResultColumn]; ok {
return v.Int(), nil
}
// Fields number check.
var recordFields = m.getRecordFields(all[0]) var recordFields = m.getRecordFields(all[0])
if len(recordFields) == 1 { if len(recordFields) == 1 {
for _, v := range all[0] { for _, v := range all[0] {
@ -616,17 +659,17 @@ func (m *Model) Having(having interface{}, args ...interface{}) *Model {
// The parameter `limit1` specifies whether limits querying only one record if m.limit is not set. // The parameter `limit1` specifies whether limits querying only one record if m.limit is not set.
// The optional parameter `where` is the same as the parameter of Model.Where function, // The optional parameter `where` is the same as the parameter of Model.Where function,
// see Model.Where. // see Model.Where.
func (m *Model) doGetAll(ctx context.Context, limit1 bool, where ...interface{}) (Result, error) { func (m *Model) doGetAll(ctx context.Context, selectType SelectType, limit1 bool, where ...interface{}) (Result, error) {
if len(where) > 0 { if len(where) > 0 {
return m.Where(where[0], where[1:]...).All() return m.Where(where[0], where[1:]...).All()
} }
sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(ctx, queryTypeNormal, limit1) sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(ctx, selectType, limit1)
return m.doGetAllBySql(ctx, queryTypeNormal, sqlWithHolder, holderArgs...) return m.doGetAllBySql(ctx, selectType, sqlWithHolder, holderArgs...)
} }
// doGetAllBySql does the select statement on the database. // doGetAllBySql does the select statement on the database.
func (m *Model) doGetAllBySql( func (m *Model) doGetAllBySql(
ctx context.Context, queryType queryType, sql string, args ...interface{}, ctx context.Context, selectType SelectType, sql string, args ...interface{},
) (result Result, err error) { ) (result Result, err error) {
if result, err = m.getSelectResultFromCache(ctx, sql, args...); err != nil || result != nil { if result, err = m.getSelectResultFromCache(ctx, sql, args...); err != nil || result != nil {
return return
@ -639,24 +682,25 @@ func (m *Model) doGetAllBySql(
}, },
handler: m.hookHandler.Select, handler: m.hookHandler.Select,
}, },
Model: m, Model: m,
Table: m.tables, Table: m.tables,
Sql: sql, Sql: sql,
Args: m.mergeArguments(args), Args: m.mergeArguments(args),
SelectType: selectType,
} }
if result, err = in.Next(ctx); err != nil { if result, err = in.Next(ctx); err != nil {
return return
} }
err = m.saveSelectResultToCache(ctx, queryType, result, sql, args...) err = m.saveSelectResultToCache(ctx, selectType, result, sql, args...)
return return
} }
func (m *Model) getFormattedSqlAndArgs( func (m *Model) getFormattedSqlAndArgs(
ctx context.Context, queryType queryType, limit1 bool, ctx context.Context, selectType SelectType, limit1 bool,
) (sqlWithHolder string, holderArgs []interface{}) { ) (sqlWithHolder string, holderArgs []interface{}) {
switch queryType { switch selectType {
case queryTypeCount: case SelectTypeCount:
queryFields := "COUNT(1)" queryFields := "COUNT(1)"
if len(m.fields) > 0 { if len(m.fields) > 0 {
// DO NOT quote the m.fields here, in case of fields like: // DO NOT quote the m.fields here, in case of fields like:
@ -698,7 +742,7 @@ func (m *Model) getFormattedSqlAndArgs(
func (m *Model) getHolderAndArgsAsSubModel(ctx context.Context) (holder string, args []interface{}) { func (m *Model) getHolderAndArgsAsSubModel(ctx context.Context) (holder string, args []interface{}) {
holder, args = m.getFormattedSqlAndArgs( holder, args = m.getFormattedSqlAndArgs(
ctx, queryTypeNormal, false, ctx, SelectTypeDefault, false,
) )
args = m.mergeArguments(args) args = m.mergeArguments(args)
return return