1
0
mirror of https://github.com/gogf/gf.git synced 2025-04-05 11:18:50 +08:00

fix: #3238 first column might be overwritten in interal context data in multiple goroutines querying (#3476)

This commit is contained in:
John Guo 2024-04-16 19:31:06 +08:00 committed by GitHub
parent 75763735c4
commit bbcf49db98
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 114 additions and 36 deletions

View File

@ -9,6 +9,7 @@ package mysql_test
import (
"context"
"fmt"
"sync"
"testing"
"time"
@ -1125,3 +1126,37 @@ func Test_Issue2643(t *testing.T) {
t.Assert(gstr.Contains(sqlContent, expectKey2), true)
})
}
// https://github.com/gogf/gf/issues/3238
func Test_Issue3238(t *testing.T) {
table := createInitTable()
defer dropTable(table)
gtest.C(t, func(t *gtest.T) {
for i := 0; i < 100; i++ {
_, err := db.Ctx(ctx).Model(table).Hook(gdb.HookHandler{
Select: func(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) {
result, err = in.Next(ctx)
if err != nil {
return
}
var wg sync.WaitGroup
for _, record := range result {
wg.Add(1)
go func(record gdb.Record) {
defer wg.Done()
id, _ := db.Ctx(ctx).Model(table).WherePri(1).Value(`id`)
nickname, _ := db.Ctx(ctx).Model(table).WherePri(1).Value(`nickname`)
t.Assert(id.Int(), 1)
t.Assert(nickname.String(), "name_1")
}(record)
}
wg.Wait()
return
},
},
).All()
t.AssertNil(err)
}
})
}

View File

@ -744,10 +744,10 @@ func (c *Core) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err error
node.Name = nodeSchema
}
// Update the configuration object in internal data.
internalData := c.GetInternalCtxDataFromCtx(ctx)
if internalData != nil {
internalData.ConfigNode = node
if err = c.setConfigNodeToCtx(ctx, node); err != nil {
return
}
// Cache the underlying connection pool object by node.
var (
instanceCacheFunc = func() interface{} {

View File

@ -56,7 +56,7 @@ func (c *Core) Ctx(ctx context.Context) DB {
panic(err)
}
newCore.ctx = WithDB(ctx, newCore.db)
newCore.ctx = c.InjectInternalCtxData(newCore.ctx)
newCore.ctx = c.injectInternalCtxData(newCore.ctx)
return newCore.db
}
@ -67,7 +67,7 @@ func (c *Core) GetCtx() context.Context {
if ctx == nil {
ctx = context.TODO()
}
return c.InjectInternalCtxData(ctx)
return c.injectInternalCtxData(ctx)
}
// GetCtxTimeout returns the context and cancel function for specified timeout type.

View File

@ -208,15 +208,15 @@ func (c *Core) SetMaxConnLifeTime(d time.Duration) {
// GetConfig returns the current used node configuration.
func (c *Core) GetConfig() *ConfigNode {
internalData := c.GetInternalCtxDataFromCtx(c.db.GetCtx())
if internalData != nil && internalData.ConfigNode != nil {
var configNode = c.getConfigNodeFromCtx(c.db.GetCtx())
if configNode != nil {
// Note:
// It so here checks and returns the config from current DB,
// if different schemas between current DB and config.Name from context,
// for example, in nested transaction scenario, the context is passed all through the logic procedure,
// but the config.Name from context may be still the original one from the first transaction object.
if c.config.Name == internalData.ConfigNode.Name {
return internalData.ConfigNode
if c.config.Name == configNode.Name {
return configNode
}
}
return c.config

View File

@ -8,18 +8,22 @@ package gdb
import (
"context"
"sync"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/os/gctx"
)
// internalCtxData stores data in ctx for internal usage purpose.
type internalCtxData struct {
// Operation DB.
DB DB
sync.Mutex
// Used configuration node in current operation.
ConfigNode *ConfigNode
}
// column stores column data in ctx for internal usage purpose.
type internalColumnData struct {
// The first column in result response from database server.
// This attribute is used for Value/Count selection statement purpose,
// which is to avoid HOOK handler that might modify the result columns
@ -28,7 +32,8 @@ type internalCtxData struct {
}
const (
internalCtxDataKeyInCtx gctx.StrKey = "InternalCtxData"
internalCtxDataKeyInCtx gctx.StrKey = "InternalCtxData"
internalColumnDataKeyInCtx gctx.StrKey = "InternalColumnData"
// `ignoreResultKeyInCtx` is a mark for some db drivers that do not support `RowsAffected` function,
// for example: `clickhouse`. The `clickhouse` does not support fetching insert/update results,
@ -37,20 +42,46 @@ const (
ignoreResultKeyInCtx gctx.StrKey = "IgnoreResult"
)
func (c *Core) InjectInternalCtxData(ctx context.Context) context.Context {
func (c *Core) injectInternalCtxData(ctx context.Context) context.Context {
// If the internal data is already injected, it does nothing.
if ctx.Value(internalCtxDataKeyInCtx) != nil {
return ctx
}
return context.WithValue(ctx, internalCtxDataKeyInCtx, &internalCtxData{
DB: c.db,
ConfigNode: c.config,
})
}
func (c *Core) GetInternalCtxDataFromCtx(ctx context.Context) *internalCtxData {
if v := ctx.Value(internalCtxDataKeyInCtx); v != nil {
return v.(*internalCtxData)
func (c *Core) setConfigNodeToCtx(ctx context.Context, node *ConfigNode) error {
value := ctx.Value(internalCtxDataKeyInCtx)
if value == nil {
return gerror.NewCode(gcode.CodeInternalError, `no internal data found in context`)
}
data := value.(*internalCtxData)
data.Lock()
defer data.Unlock()
data.ConfigNode = node
return nil
}
func (c *Core) getConfigNodeFromCtx(ctx context.Context) *ConfigNode {
if value := ctx.Value(internalCtxDataKeyInCtx); value != nil {
data := value.(*internalCtxData)
data.Lock()
defer data.Unlock()
return data.ConfigNode
}
return nil
}
func (c *Core) injectInternalColumn(ctx context.Context) context.Context {
return context.WithValue(ctx, internalColumnDataKeyInCtx, &internalColumnData{})
}
func (c *Core) getInternalColumnFromCtx(ctx context.Context) *internalColumnData {
if v := ctx.Value(internalColumnDataKeyInCtx); v != nil {
return v.(*internalColumnData)
}
return nil
}

View File

@ -72,7 +72,7 @@ func (c *Core) Transaction(ctx context.Context, f func(ctx context.Context, tx T
if ctx == nil {
ctx = c.db.GetCtx()
}
ctx = c.InjectInternalCtxData(ctx)
ctx = c.injectInternalCtxData(ctx)
// Check transaction object from context.
var tx TX
tx = TXFromCtx(ctx, c.db.GetGroup())
@ -160,7 +160,7 @@ func (tx *TXCore) transactionKeyForNestedPoint() string {
func (tx *TXCore) Ctx(ctx context.Context) TX {
tx.ctx = ctx
if tx.ctx != nil {
tx.ctx = tx.db.GetCore().InjectInternalCtxData(tx.ctx)
tx.ctx = tx.db.GetCore().injectInternalCtxData(tx.ctx)
}
return tx
}

View File

@ -156,9 +156,6 @@ func (c *Core) DoFilter(ctx context.Context, link Link, sql string, args []inter
// DoCommit commits current sql and arguments to underlying sql driver.
func (c *Core) DoCommit(ctx context.Context, in DoCommitInput) (out DoCommitOutput, err error) {
// Inject internal data into ctx, especially for transaction creating.
ctx = c.InjectInternalCtxData(ctx)
var (
sqlTx *sql.Tx
sqlStmt *sql.Stmt
@ -420,7 +417,7 @@ func (c *Core) RowsToResult(ctx context.Context, rows *sql.Rows) (Result, error)
}
if len(columnTypes) > 0 {
if internalData := c.GetInternalCtxDataFromCtx(ctx); internalData != nil {
if internalData := c.getInternalColumnFromCtx(ctx); internalData != nil {
internalData.FirstResultColumn = columnTypes[0].Name()
}
}

View File

@ -69,10 +69,11 @@ func (m *Model) getSelectResultFromCache(ctx context.Context, sql string, args .
cacheItem *selectCacheItem
cacheKey = m.makeSelectCacheKey(sql, args...)
cacheObj = m.db.GetCache()
core = m.db.GetCore()
)
defer func() {
if cacheItem != nil {
if internalData := m.db.GetCore().GetInternalCtxDataFromCtx(ctx); internalData != nil {
if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil {
if cacheItem.FirstResultColumn != "" {
internalData.FirstResultColumn = cacheItem.FirstResultColumn
}
@ -106,9 +107,10 @@ func (m *Model) saveSelectResultToCache(
}
// Special handler for Value/Count operations result.
if len(result) > 0 {
var core = m.db.GetCore()
switch queryType {
case queryTypeValue, queryTypeCount:
if internalData := m.db.GetCore().GetInternalCtxDataFromCtx(ctx); internalData != nil {
if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil {
if result[0][internalData.FirstResultColumn].IsEmpty() {
result = nil
}
@ -124,10 +126,13 @@ func (m *Model) saveSelectResultToCache(
result = nil
}
}
var cacheItem = &selectCacheItem{
Result: result,
}
if internalData := m.db.GetCore().GetInternalCtxDataFromCtx(ctx); internalData != nil {
var (
core = m.db.GetCore()
cacheItem = &selectCacheItem{
Result: result,
}
)
if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil {
cacheItem.FirstResultColumn = internalData.FirstResultColumn
}
if errCache := cacheObj.Set(ctx, cacheKey, cacheItem, m.cacheOption.Duration); errCache != nil {

View File

@ -139,9 +139,13 @@ func (m *Model) Array(fieldsAndWhere ...interface{}) ([]Value, error) {
if err != nil {
return nil, err
}
var field string
var (
field string
core = m.db.GetCore()
ctx = core.injectInternalColumn(m.GetCtx())
)
if len(all) > 0 {
if internalData := m.db.GetCore().GetInternalCtxDataFromCtx(m.GetCtx()); internalData != nil {
if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil {
field = internalData.FirstResultColumn
} else {
return nil, gerror.NewCode(
@ -376,7 +380,10 @@ func (m *Model) ScanList(structSlicePointer interface{}, bindToAttrName string,
// and fieldsAndWhere[1:] is treated as where condition fields.
// Also see Model.Fields and Model.Where functions.
func (m *Model) Value(fieldsAndWhere ...interface{}) (Value, error) {
var ctx = m.GetCtx()
var (
core = m.db.GetCore()
ctx = core.injectInternalColumn(m.GetCtx())
)
if len(fieldsAndWhere) > 0 {
if len(fieldsAndWhere) > 2 {
return m.Fields(gconv.String(fieldsAndWhere[0])).Where(fieldsAndWhere[1], fieldsAndWhere[2:]...).Value()
@ -394,7 +401,7 @@ func (m *Model) Value(fieldsAndWhere ...interface{}) (Value, error) {
return nil, err
}
if len(all) > 0 {
if internalData := m.db.GetCore().GetInternalCtxDataFromCtx(ctx); internalData != nil {
if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil {
if v, ok := all[0][internalData.FirstResultColumn]; ok {
return v, nil
}
@ -412,7 +419,10 @@ func (m *Model) Value(fieldsAndWhere ...interface{}) (Value, error) {
// The optional parameter `where` is the same as the parameter of Model.Where function,
// see Model.Where.
func (m *Model) Count(where ...interface{}) (int, error) {
var ctx = m.GetCtx()
var (
core = m.db.GetCore()
ctx = core.injectInternalColumn(m.GetCtx())
)
if len(where) > 0 {
return m.Where(where[0], where[1:]...).Count()
}
@ -424,7 +434,7 @@ func (m *Model) Count(where ...interface{}) (int, error) {
return 0, err
}
if len(all) > 0 {
if internalData := m.db.GetCore().GetInternalCtxDataFromCtx(ctx); internalData != nil {
if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil {
if v, ok := all[0][internalData.FirstResultColumn]; ok {
return v.Int(), nil
}