mirror of
https://github.com/gogf/gf.git
synced 2025-04-05 03:05:05 +08:00
fix(database/gdb): issue where the Count/Value/Array
query logic was incompatible with the old version when users extended the returned result fields using the Select
Hook (#3995)
This commit is contained in:
parent
42eae41599
commit
2c916f8222
@ -13,6 +13,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
@ -1283,12 +1284,12 @@ func Test_Issue3754(t *testing.T) {
|
||||
func Test_Issue3626(t *testing.T) {
|
||||
table := "issue3626"
|
||||
array := gstr.SplitAndTrim(gtest.DataContent(`issue3626.sql`), ";")
|
||||
defer dropTable(table)
|
||||
for _, v := range array {
|
||||
if _, err := db.Exec(ctx, v); err != nil {
|
||||
gtest.Error(err)
|
||||
}
|
||||
}
|
||||
defer dropTable(table)
|
||||
|
||||
// Insert.
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
@ -1377,3 +1378,34 @@ func Test_Issue3932(t *testing.T) {
|
||||
t.Assert(one["id"], 10)
|
||||
})
|
||||
}
|
||||
|
||||
// https://github.com/gogf/gf/issues/3968
|
||||
func Test_Issue3968(t *testing.T) {
|
||||
table := createInitTable()
|
||||
defer dropTable(table)
|
||||
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
var hook = gdb.HookHandler{
|
||||
Select: func(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) {
|
||||
result, err = in.Next(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result != nil {
|
||||
for i, _ := range result {
|
||||
result[i]["location"] = gvar.New("ny")
|
||||
}
|
||||
}
|
||||
return
|
||||
},
|
||||
}
|
||||
var (
|
||||
count int
|
||||
result gdb.Result
|
||||
)
|
||||
err := db.Model(table).Hook(hook).ScanAndCount(&result, &count, false)
|
||||
t.AssertNil(err)
|
||||
t.Assert(count, 10)
|
||||
t.Assert(len(result), 10)
|
||||
})
|
||||
}
|
||||
|
@ -396,12 +396,13 @@ const (
|
||||
linkPattern = `(\w+):([\w\-\$]*):(.*?)@(\w+?)\((.+?)\)/{0,1}([^\?]*)\?{0,1}(.*)`
|
||||
)
|
||||
|
||||
type queryType int
|
||||
type SelectType int
|
||||
|
||||
const (
|
||||
queryTypeNormal queryType = iota
|
||||
queryTypeCount
|
||||
queryTypeValue
|
||||
SelectTypeDefault SelectType = iota
|
||||
SelectTypeCount
|
||||
SelectTypeValue
|
||||
SelectTypeArray
|
||||
)
|
||||
|
||||
type joinOperator string
|
||||
@ -700,13 +701,13 @@ func getConfigNodeByWeight(cg ConfigGroup) *ConfigNode {
|
||||
}
|
||||
// Exclude the right border value.
|
||||
var (
|
||||
min = 0
|
||||
max = 0
|
||||
random = grand.N(0, total-1)
|
||||
minWeight = 0
|
||||
maxWeight = 0
|
||||
random = grand.N(0, total-1)
|
||||
)
|
||||
for i := 0; i < len(cg); i++ {
|
||||
max = min + cg[i].Weight*100
|
||||
if random >= min && random < max {
|
||||
maxWeight = minWeight + cg[i].Weight*100
|
||||
if random >= minWeight && random < maxWeight {
|
||||
// ====================================================
|
||||
// Return a COPY of the ConfigNode.
|
||||
// ====================================================
|
||||
@ -714,7 +715,7 @@ func getConfigNodeByWeight(cg ConfigGroup) *ConfigNode {
|
||||
node = cg[i]
|
||||
return &node
|
||||
}
|
||||
min = max
|
||||
minWeight = maxWeight
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -278,7 +278,7 @@ func (c *Core) doUnion(ctx context.Context, unionType int, unions ...*Model) *Mo
|
||||
unionTypeStr = "UNION"
|
||||
}
|
||||
for _, v := range unions {
|
||||
sqlWithHolder, holderArgs := v.getFormattedSqlAndArgs(ctx, queryTypeNormal, false)
|
||||
sqlWithHolder, holderArgs := v.getFormattedSqlAndArgs(ctx, SelectTypeDefault, false)
|
||||
if composedSqlStr == "" {
|
||||
composedSqlStr += fmt.Sprintf(`(%s)`, sqlWithHolder)
|
||||
} else {
|
||||
|
@ -23,8 +23,6 @@ type internalCtxData struct {
|
||||
}
|
||||
|
||||
// column stores column data in ctx for internal usage purpose.
|
||||
// Deprecated.
|
||||
// TODO remove this usage in future.
|
||||
type internalColumnData struct {
|
||||
// The first column in result response from database server.
|
||||
// This attribute is used for Value/Count selection statement purpose,
|
||||
|
@ -90,7 +90,7 @@ func (m *Model) getSelectResultFromCache(ctx context.Context, sql string, args .
|
||||
}
|
||||
|
||||
func (m *Model) saveSelectResultToCache(
|
||||
ctx context.Context, queryType queryType, result Result, sql string, args ...interface{},
|
||||
ctx context.Context, selectType SelectType, result Result, sql string, args ...interface{},
|
||||
) (err error) {
|
||||
if !m.cacheEnabled || m.tx != nil {
|
||||
return
|
||||
@ -108,18 +108,19 @@ func (m *Model) saveSelectResultToCache(
|
||||
// Special handler for Value/Count operations result.
|
||||
if len(result) > 0 {
|
||||
var core = m.db.GetCore()
|
||||
switch queryType {
|
||||
case queryTypeValue, queryTypeCount:
|
||||
switch selectType {
|
||||
case SelectTypeValue, SelectTypeArray, SelectTypeCount:
|
||||
if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil {
|
||||
if result[0][internalData.FirstResultColumn].IsEmpty() {
|
||||
result = nil
|
||||
}
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// In case of Cache Penetration.
|
||||
if result.IsEmpty() {
|
||||
if result != nil && result.IsEmpty() {
|
||||
if m.cacheOption.Force {
|
||||
result = Result{}
|
||||
} else {
|
||||
|
@ -66,11 +66,12 @@ type internalParamHookDelete struct {
|
||||
// which is usually not be interesting for upper business hook handler.
|
||||
type HookSelectInput struct {
|
||||
internalParamHookSelect
|
||||
Model *Model // Current operation Model.
|
||||
Table string // The table name that to be used. Update this attribute to change target table name.
|
||||
Schema string // The schema name that to be used. Update this attribute to change target schema name.
|
||||
Sql string // The sql string that to be committed.
|
||||
Args []interface{} // The arguments of sql.
|
||||
Model *Model // Current operation Model.
|
||||
Table string // The table name that to be used. Update this attribute to change target table name.
|
||||
Schema string // The schema name that to be used. Update this attribute to change target schema name.
|
||||
Sql string // The sql string that to be committed.
|
||||
Args []interface{} // The arguments of sql.
|
||||
SelectType SelectType // The type of this SELECT operation.
|
||||
}
|
||||
|
||||
// HookInsertInput holds the parameters for insert hook operation.
|
||||
|
@ -28,7 +28,7 @@ import (
|
||||
// see Model.Where.
|
||||
func (m *Model) All(where ...interface{}) (Result, error) {
|
||||
var ctx = m.GetCtx()
|
||||
return m.doGetAll(ctx, false, where...)
|
||||
return m.doGetAll(ctx, SelectTypeDefault, false, where...)
|
||||
}
|
||||
|
||||
// AllAndCount retrieves all records and the total count of records from the model.
|
||||
@ -69,7 +69,7 @@ func (m *Model) AllAndCount(useFieldForCount bool) (result Result, totalCount in
|
||||
}
|
||||
|
||||
// Retrieve all records
|
||||
result, err = m.doGetAll(m.GetCtx(), false)
|
||||
result, err = m.doGetAll(m.GetCtx(), SelectTypeDefault, false)
|
||||
return
|
||||
}
|
||||
|
||||
@ -110,7 +110,7 @@ func (m *Model) One(where ...interface{}) (Record, error) {
|
||||
if len(where) > 0 {
|
||||
return m.Where(where[0], where[1:]...).One()
|
||||
}
|
||||
all, err := m.doGetAll(ctx, true)
|
||||
all, err := m.doGetAll(ctx, SelectTypeDefault, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -136,24 +136,41 @@ func (m *Model) Array(fieldsAndWhere ...interface{}) ([]Value, error) {
|
||||
return m.Fields(gconv.String(fieldsAndWhere[0])).Array()
|
||||
}
|
||||
}
|
||||
all, err := m.All()
|
||||
|
||||
var (
|
||||
field string
|
||||
core = m.db.GetCore()
|
||||
ctx = core.injectInternalColumn(m.GetCtx())
|
||||
)
|
||||
all, err := m.doGetAll(ctx, SelectTypeArray, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var field string
|
||||
if len(all) > 0 {
|
||||
var recordFields = m.getRecordFields(all[0])
|
||||
if len(recordFields) > 1 {
|
||||
// it returns error if there are multiple fields in the result record.
|
||||
return nil, gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`invalid fields for "Array" operation, result fields number "%d"%s, but expect one`,
|
||||
len(recordFields),
|
||||
gjson.MustEncodeString(recordFields),
|
||||
internalData := core.getInternalColumnFromCtx(ctx)
|
||||
if internalData == nil {
|
||||
return nil, gerror.NewCode(
|
||||
gcode.CodeInternalError,
|
||||
`query count error: the internal context data is missing. there's internal issue should be fixed`,
|
||||
)
|
||||
}
|
||||
if len(recordFields) == 1 {
|
||||
field = recordFields[0]
|
||||
// If FirstResultColumn present, it returns the value of the first record of the first field.
|
||||
// It means it use no cache mechanism, while cache mechanism makes `internalData` missing.
|
||||
field = internalData.FirstResultColumn
|
||||
if field == "" {
|
||||
// Fields number check.
|
||||
var recordFields = m.getRecordFields(all[0])
|
||||
if len(recordFields) == 1 {
|
||||
field = recordFields[0]
|
||||
} else {
|
||||
// it returns error if there are multiple fields in the result record.
|
||||
return nil, gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`invalid fields for "Array" operation, result fields number "%d"%s, but expect one`,
|
||||
len(recordFields),
|
||||
gjson.MustEncodeString(recordFields),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
return all.Array(field), nil
|
||||
@ -398,13 +415,26 @@ func (m *Model) Value(fieldsAndWhere ...interface{}) (Value, error) {
|
||||
}
|
||||
}
|
||||
var (
|
||||
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, queryTypeValue, true)
|
||||
all, err = m.doGetAllBySql(ctx, queryTypeValue, sqlWithHolder, holderArgs...)
|
||||
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, SelectTypeValue, true)
|
||||
all, err = m.doGetAllBySql(ctx, SelectTypeValue, sqlWithHolder, holderArgs...)
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(all) > 0 {
|
||||
internalData := core.getInternalColumnFromCtx(ctx)
|
||||
if internalData == nil {
|
||||
return nil, gerror.NewCode(
|
||||
gcode.CodeInternalError,
|
||||
`query count error: the internal context data is missing. there's internal issue should be fixed`,
|
||||
)
|
||||
}
|
||||
// If FirstResultColumn present, it returns the value of the first record of the first field.
|
||||
// It means it use no cache mechanism, while cache mechanism makes `internalData` missing.
|
||||
if v, ok := all[0][internalData.FirstResultColumn]; ok {
|
||||
return v, nil
|
||||
}
|
||||
// Fields number check.
|
||||
var recordFields = m.getRecordFields(all[0])
|
||||
if len(recordFields) == 1 {
|
||||
for _, v := range all[0] {
|
||||
@ -445,13 +475,26 @@ func (m *Model) Count(where ...interface{}) (int, error) {
|
||||
return m.Where(where[0], where[1:]...).Count()
|
||||
}
|
||||
var (
|
||||
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, queryTypeCount, false)
|
||||
all, err = m.doGetAllBySql(ctx, queryTypeCount, sqlWithHolder, holderArgs...)
|
||||
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, SelectTypeCount, false)
|
||||
all, err = m.doGetAllBySql(ctx, SelectTypeCount, sqlWithHolder, holderArgs...)
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(all) > 0 {
|
||||
internalData := core.getInternalColumnFromCtx(ctx)
|
||||
if internalData == nil {
|
||||
return 0, gerror.NewCode(
|
||||
gcode.CodeInternalError,
|
||||
`query count error: the internal context data is missing. there's internal issue should be fixed`,
|
||||
)
|
||||
}
|
||||
// If FirstResultColumn present, it returns the value of the first record of the first field.
|
||||
// It means it use no cache mechanism, while cache mechanism makes `internalData` missing.
|
||||
if v, ok := all[0][internalData.FirstResultColumn]; ok {
|
||||
return v.Int(), nil
|
||||
}
|
||||
// Fields number check.
|
||||
var recordFields = m.getRecordFields(all[0])
|
||||
if len(recordFields) == 1 {
|
||||
for _, v := range all[0] {
|
||||
@ -616,17 +659,17 @@ func (m *Model) Having(having interface{}, args ...interface{}) *Model {
|
||||
// The parameter `limit1` specifies whether limits querying only one record if m.limit is not set.
|
||||
// The optional parameter `where` is the same as the parameter of Model.Where function,
|
||||
// see Model.Where.
|
||||
func (m *Model) doGetAll(ctx context.Context, limit1 bool, where ...interface{}) (Result, error) {
|
||||
func (m *Model) doGetAll(ctx context.Context, selectType SelectType, limit1 bool, where ...interface{}) (Result, error) {
|
||||
if len(where) > 0 {
|
||||
return m.Where(where[0], where[1:]...).All()
|
||||
}
|
||||
sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(ctx, queryTypeNormal, limit1)
|
||||
return m.doGetAllBySql(ctx, queryTypeNormal, sqlWithHolder, holderArgs...)
|
||||
sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(ctx, selectType, limit1)
|
||||
return m.doGetAllBySql(ctx, selectType, sqlWithHolder, holderArgs...)
|
||||
}
|
||||
|
||||
// doGetAllBySql does the select statement on the database.
|
||||
func (m *Model) doGetAllBySql(
|
||||
ctx context.Context, queryType queryType, sql string, args ...interface{},
|
||||
ctx context.Context, selectType SelectType, sql string, args ...interface{},
|
||||
) (result Result, err error) {
|
||||
if result, err = m.getSelectResultFromCache(ctx, sql, args...); err != nil || result != nil {
|
||||
return
|
||||
@ -639,24 +682,25 @@ func (m *Model) doGetAllBySql(
|
||||
},
|
||||
handler: m.hookHandler.Select,
|
||||
},
|
||||
Model: m,
|
||||
Table: m.tables,
|
||||
Sql: sql,
|
||||
Args: m.mergeArguments(args),
|
||||
Model: m,
|
||||
Table: m.tables,
|
||||
Sql: sql,
|
||||
Args: m.mergeArguments(args),
|
||||
SelectType: selectType,
|
||||
}
|
||||
if result, err = in.Next(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = m.saveSelectResultToCache(ctx, queryType, result, sql, args...)
|
||||
err = m.saveSelectResultToCache(ctx, selectType, result, sql, args...)
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Model) getFormattedSqlAndArgs(
|
||||
ctx context.Context, queryType queryType, limit1 bool,
|
||||
ctx context.Context, selectType SelectType, limit1 bool,
|
||||
) (sqlWithHolder string, holderArgs []interface{}) {
|
||||
switch queryType {
|
||||
case queryTypeCount:
|
||||
switch selectType {
|
||||
case SelectTypeCount:
|
||||
queryFields := "COUNT(1)"
|
||||
if len(m.fields) > 0 {
|
||||
// DO NOT quote the m.fields here, in case of fields like:
|
||||
@ -698,7 +742,7 @@ func (m *Model) getFormattedSqlAndArgs(
|
||||
|
||||
func (m *Model) getHolderAndArgsAsSubModel(ctx context.Context) (holder string, args []interface{}) {
|
||||
holder, args = m.getFormattedSqlAndArgs(
|
||||
ctx, queryTypeNormal, false,
|
||||
ctx, SelectTypeDefault, false,
|
||||
)
|
||||
args = m.mergeArguments(args)
|
||||
return
|
||||
|
Loading…
x
Reference in New Issue
Block a user