1
0
mirror of https://github.com/gogf/gf.git synced 2025-04-05 11:18:50 +08:00

imrove context handling for package gdb

This commit is contained in:
John Guo 2022-03-24 17:51:49 +08:00
parent cc01629b57
commit 09ba1bf1fb
13 changed files with 106 additions and 89 deletions

View File

@ -522,8 +522,11 @@ func getConfigNodeByWeight(cg ConfigGroup) *ConfigNode {
// The parameter `master` specifies whether retrieves master node connection if
// master-slave nodes are configured.
func (c *Core) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err error) {
var (
ctx = c.db.GetCtx()
node *ConfigNode
)
// Load balance.
var node *ConfigNode
if c.group != "" {
node, err = getConfigNodeByGroup(c.group, master)
if err != nil {
@ -549,20 +552,12 @@ func (c *Core) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err error
}
// Cache the underlying connection pool object by node.
v := c.links.GetOrSetFuncLock(node.String(), func() interface{} {
intlog.Printf(
c.db.GetCtx(),
`open new connection, master:%#v, config:%#v, node:%#v`,
master, c.config, node,
)
intlog.Printf(ctx, `open new connection, master:%#v, config:%#v, node:%#v`, master, c.config, node)
defer func() {
if err != nil {
intlog.Printf(c.db.GetCtx(), `open new connection failed: %v, %#v`, err, node)
intlog.Printf(ctx, `open new connection failed: %v, %#v`, err, node)
} else {
intlog.Printf(
c.db.GetCtx(),
`open new connection success, master:%#v, config:%#v, node:%#v`,
master, c.config, node,
)
intlog.Printf(ctx, `open new connection success, master:%#v, config:%#v, node:%#v`, master, c.config, node)
}
}()

View File

@ -67,9 +67,9 @@ func (c *Core) GetCtx() context.Context {
}
// GetCtxTimeout returns the context and cancel function for specified timeout type.
func (c *Core) GetCtxTimeout(timeoutType int, ctx context.Context) (context.Context, context.CancelFunc) {
func (c *Core) GetCtxTimeout(ctx context.Context, timeoutType int) (context.Context, context.CancelFunc) {
if ctx == nil {
ctx = c.GetCtx()
ctx = c.db.GetCtx()
} else {
ctx = context.WithValue(ctx, "WrappedByGetCtxTimeout", nil)
}
@ -255,17 +255,18 @@ func (c *Core) GetCount(ctx context.Context, sql string, args ...interface{}) (i
// Union does "(SELECT xxx FROM xxx) UNION (SELECT xxx FROM xxx) ..." statement.
func (c *Core) Union(unions ...*Model) *Model {
return c.doUnion(unionTypeNormal, unions...)
var ctx = c.db.GetCtx()
return c.doUnion(ctx, unionTypeNormal, unions...)
}
// UnionAll does "(SELECT xxx FROM xxx) UNION ALL (SELECT xxx FROM xxx) ..." statement.
func (c *Core) UnionAll(unions ...*Model) *Model {
return c.doUnion(unionTypeAll, unions...)
var ctx = c.db.GetCtx()
return c.doUnion(ctx, unionTypeAll, unions...)
}
func (c *Core) doUnion(unionType int, unions ...*Model) *Model {
func (c *Core) doUnion(ctx context.Context, unionType int, unions ...*Model) *Model {
var (
ctx = c.db.GetCtx()
unionTypeStr string
composedSqlStr string
composedArgs = make([]interface{}, 0)
@ -289,10 +290,11 @@ func (c *Core) doUnion(unionType int, unions ...*Model) *Model {
// PingMaster pings the master node to check authentication or keeps the connection alive.
func (c *Core) PingMaster() error {
var ctx = c.db.GetCtx()
if master, err := c.db.Master(); err != nil {
return err
} else {
if err = master.PingContext(c.GetCtx()); err != nil {
if err = master.PingContext(ctx); err != nil {
err = gerror.WrapCode(gcode.CodeDbOperationError, err, `master.Ping failed`)
}
return err
@ -301,10 +303,11 @@ func (c *Core) PingMaster() error {
// PingSlave pings the slave node to check authentication or keeps the connection alive.
func (c *Core) PingSlave() error {
var ctx = c.db.GetCtx()
if slave, err := c.db.Slave(); err != nil {
return err
} else {
if err = slave.PingContext(c.GetCtx()); err != nil {
if err = slave.PingContext(ctx); err != nil {
err = gerror.WrapCode(gcode.CodeDbOperationError, err, `slave.Ping failed`)
}
return err
@ -663,21 +666,22 @@ func (c *Core) writeSqlToLogger(ctx context.Context, sql *Sql) {
// HasTable determine whether the table name exists in the database.
func (c *Core) HasTable(name string) (bool, error) {
result, err := c.GetCache().GetOrSetFuncLock(
c.GetCtx(),
fmt.Sprintf(`HasTable: %s`, name),
func(ctx context.Context) (interface{}, error) {
tableList, err := c.db.Tables(ctx)
if err != nil {
return false, err
var (
ctx = c.db.GetCtx()
cacheKey = fmt.Sprintf(`HasTable: %s`, name)
)
result, err := c.GetCache().GetOrSetFuncLock(ctx, cacheKey, func(ctx context.Context) (interface{}, error) {
tableList, err := c.db.Tables(ctx)
if err != nil {
return false, err
}
for _, table := range tableList {
if table == name {
return true, nil
}
for _, table := range tableList {
if table == name {
return true, nil
}
}
return false, nil
}, 0,
}
return false, nil
}, 0,
)
if err != nil {
return false, err

View File

@ -71,7 +71,7 @@ func (c *Core) doBeginCtx(ctx context.Context) (*TX, error) {
func (c *Core) Transaction(ctx context.Context, f func(ctx context.Context, tx *TX) error) (err error) {
var tx *TX
if ctx == nil {
ctx = c.GetCtx()
ctx = c.db.GetCtx()
}
// Check transaction object from context.
tx = TXFromCtx(ctx, c.db.GetGroup())

View File

@ -234,7 +234,7 @@ func (c *Core) DoCommit(ctx context.Context, in DoCommitInput) (out DoCommitOutp
out.RawResult = sqlStmt
case SqlTypeStmtExecContext:
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctxTimeoutTypeExec, ctx)
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeExec)
defer cancelFuncForTimeout()
if c.db.GetDryRun() {
sqlResult = new(SqlResult)
@ -244,13 +244,13 @@ func (c *Core) DoCommit(ctx context.Context, in DoCommitInput) (out DoCommitOutp
out.RawResult = sqlResult
case SqlTypeStmtQueryContext:
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctxTimeoutTypeQuery, ctx)
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeQuery)
defer cancelFuncForTimeout()
stmtSqlRows, err = in.Stmt.QueryContext(ctx, in.Args...)
out.RawResult = stmtSqlRows
case SqlTypeStmtQueryRowContext:
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctxTimeoutTypeQuery, ctx)
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeQuery)
defer cancelFuncForTimeout()
stmtSqlRow = in.Stmt.QueryRowContext(ctx, in.Args...)
out.RawResult = stmtSqlRow

View File

@ -150,11 +150,12 @@ func (c *Core) Tables(schema ...string) (tables []string, err error) {
//
// It does nothing in default.
func (c *Core) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
var ctx = c.db.GetCtx()
// It does nothing if given table is empty, especially in sub-query.
if table == "" {
return map[string]*TableField{}, nil
}
return c.db.TableFields(c.GetCtx(), table, schema...)
return c.db.TableFields(ctx, table, schema...)
}
// HasField determine whether the field exists in the table.

View File

@ -38,6 +38,7 @@ func (d *DriverMysql) New(core *Core, node *ConfigNode) (DB, error) {
// Note that it converts time.Time argument to local timezone in default.
func (d *DriverMysql) Open(config *ConfigNode) (db *sql.DB, err error) {
var (
ctx = d.GetCtx()
source string
underlyingDriverName = "mysql"
)
@ -56,7 +57,7 @@ func (d *DriverMysql) Open(config *ConfigNode) (db *sql.DB, err error) {
source = fmt.Sprintf("%s&loc=%s", source, url.QueryEscape(config.Timezone))
}
}
intlog.Printf(d.GetCtx(), "Open: %s", source)
intlog.Printf(ctx, "Open: %s", source)
if db, err = sql.Open(underlyingDriverName, source); err != nil {
err = gerror.WrapCodef(
gcode.CodeDbOperationError, err,

View File

@ -8,6 +8,7 @@ package gdb
import (
"bytes"
"context"
"fmt"
"reflect"
"regexp"
@ -364,7 +365,7 @@ func isKeyValueCanBeOmitEmpty(omitEmpty bool, whereType string, key, value inter
}
// formatWhereHolder formats where statement and its arguments for `Where` and `Having` statements.
func formatWhereHolder(db DB, in formatWhereHolderInput) (newWhere string, newArgs []interface{}) {
func formatWhereHolder(ctx context.Context, db DB, in formatWhereHolderInput) (newWhere string, newArgs []interface{}) {
var (
buffer = bytes.NewBuffer(nil)
reflectInfo = reflection.OriginValueAndKind(in.Where)
@ -393,7 +394,7 @@ func formatWhereHolder(db DB, in formatWhereHolderInput) (newWhere string, newAr
}
case reflect.Struct:
// If the `where` parameter is DO struct, it then adds `OmitNil` option for this condition,
// If the `where` parameter is `DO` struct, it then adds `OmitNil` option for this condition,
// which will filter all nil parameters in `where`.
if isDoStruct(in.Where) {
in.OmitNil = true
@ -523,7 +524,9 @@ func formatWhereHolder(db DB, in formatWhereHolderInput) (newWhere string, newAr
whereStr, _ = gregex.ReplaceStringFunc(`(\?)`, whereStr, func(s string) string {
index++
if i+len(newArgs) == index {
sqlWithHolder, holderArgs := model.getFormattedSqlAndArgs(model.GetCtx(), queryTypeNormal, false)
sqlWithHolder, holderArgs := model.getFormattedSqlAndArgs(
ctx, queryTypeNormal, false,
)
newArgs = append(newArgs, holderArgs...)
// Automatically adding the brackets.
return "(" + sqlWithHolder + ")"

View File

@ -93,6 +93,7 @@ const (
// db.Model("? AS a, ? AS b", subQuery1, subQuery2)
func (c *Core) Model(tableNameQueryOrStruct ...interface{}) *Model {
var (
ctx = c.db.GetCtx()
tableStr string
tableName string
extraArgs []interface{}
@ -105,7 +106,7 @@ func (c *Core) Model(tableNameQueryOrStruct ...interface{}) *Model {
Where: conditionStr,
Args: tableNameQueryOrStruct[1:],
}
tableStr, extraArgs = formatWhereHolder(c.db, formatWhereHolderInput{
tableStr, extraArgs = formatWhereHolder(ctx, c.db, formatWhereHolderInput{
ModelWhereHolder: whereHolder,
OmitNil: false,
OmitEmpty: false,

View File

@ -7,6 +7,7 @@
package gdb
import (
"context"
"time"
"github.com/gogf/gf/v2/internal/intlog"
@ -44,11 +45,9 @@ func (m *Model) Cache(option CacheOption) *Model {
// checkAndRemoveCache checks and removes the cache in insert/update/delete statement if
// cache feature is enabled.
func (m *Model) checkAndRemoveCache() {
func (m *Model) checkAndRemoveCache(ctx context.Context) {
if m.cacheEnabled && m.cacheOption.Duration < 0 && len(m.cacheOption.Name) > 0 {
ctx := m.GetCtx()
_, err := m.db.GetCache().Remove(ctx, m.cacheOption.Name)
if err != nil {
if _, err := m.db.GetCache().Remove(ctx, m.cacheOption.Name); err != nil {
intlog.Errorf(ctx, `%+v`, err)
}
}

View File

@ -20,17 +20,18 @@ import (
// The optional parameter `where` is the same as the parameter of Model.Where function,
// see Model.Where.
func (m *Model) Delete(where ...interface{}) (result sql.Result, err error) {
var ctx = m.GetCtx()
if len(where) > 0 {
return m.Where(where[0], where[1:]...).Delete()
}
defer func() {
if err == nil {
m.checkAndRemoveCache()
m.checkAndRemoveCache(ctx)
}
}()
var (
fieldNameDelete = m.getSoftFieldNameDeleted()
conditionWhere, conditionExtra, conditionArgs = m.formatCondition(false, false)
conditionWhere, conditionExtra, conditionArgs = m.formatCondition(ctx, false, false)
)
// Soft deleting.
if !m.unscoped && fieldNameDelete != "" {
@ -47,7 +48,7 @@ func (m *Model) Delete(where ...interface{}) (result sql.Result, err error) {
Condition: conditionWhere + conditionExtra,
Args: append([]interface{}{gtime.Now().String()}, conditionArgs...),
}
return in.Next(m.GetCtx())
return in.Next(ctx)
}
conditionStr := conditionWhere + conditionExtra
if !gstr.ContainsI(conditionStr, " WHERE ") {
@ -69,5 +70,5 @@ func (m *Model) Delete(where ...interface{}) (result sql.Result, err error) {
Condition: conditionStr,
Args: conditionArgs,
}
return in.Next(m.GetCtx())
return in.Next(ctx)
}

View File

@ -7,6 +7,7 @@
package gdb
import (
"context"
"database/sql"
"reflect"
@ -38,7 +39,10 @@ func (m *Model) Batch(batch int) *Model {
// Data(g.Map{"uid": 10000, "name":"john"})
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"}).
func (m *Model) Data(data ...interface{}) *Model {
model := m.getModel()
var (
ctx = m.GetCtx()
model = m.getModel()
)
if len(data) > 1 {
if s := gconv.String(data[0]); gstr.Contains(s, "?") {
model.data = s
@ -83,7 +87,7 @@ func (m *Model) Data(data ...interface{}) *Model {
}
list := make(List, reflectInfo.OriginValue.Len())
for i := 0; i < reflectInfo.OriginValue.Len(); i++ {
list[i] = m.db.ConvertDataForRecord(m.GetCtx(), reflectInfo.OriginValue.Index(i).Interface())
list[i] = m.db.ConvertDataForRecord(ctx, reflectInfo.OriginValue.Index(i).Interface())
}
model.data = list
@ -100,15 +104,15 @@ func (m *Model) Data(data ...interface{}) *Model {
list = make(List, len(array))
)
for i := 0; i < len(array); i++ {
list[i] = m.db.ConvertDataForRecord(m.GetCtx(), array[i])
list[i] = m.db.ConvertDataForRecord(ctx, array[i])
}
model.data = list
} else {
model.data = m.db.ConvertDataForRecord(m.GetCtx(), data[0])
model.data = m.db.ConvertDataForRecord(ctx, data[0])
}
case reflect.Map:
model.data = m.db.ConvertDataForRecord(m.GetCtx(), data[0])
model.data = m.db.ConvertDataForRecord(ctx, data[0])
default:
model.data = data[0]
@ -164,18 +168,20 @@ func (m *Model) OnDuplicateEx(onDuplicateEx ...interface{}) *Model {
// The optional parameter `data` is the same as the parameter of Model.Data function,
// see Model.Data.
func (m *Model) Insert(data ...interface{}) (result sql.Result, err error) {
var ctx = m.GetCtx()
if len(data) > 0 {
return m.Data(data...).Insert()
}
return m.doInsertWithOption(InsertOptionDefault)
return m.doInsertWithOption(ctx, InsertOptionDefault)
}
// InsertAndGetId performs action Insert and returns the last insert id that automatically generated.
func (m *Model) InsertAndGetId(data ...interface{}) (lastInsertId int64, err error) {
var ctx = m.GetCtx()
if len(data) > 0 {
return m.Data(data...).InsertAndGetId()
}
result, err := m.doInsertWithOption(InsertOptionDefault)
result, err := m.doInsertWithOption(ctx, InsertOptionDefault)
if err != nil {
return 0, err
}
@ -186,20 +192,22 @@ func (m *Model) InsertAndGetId(data ...interface{}) (lastInsertId int64, err err
// The optional parameter `data` is the same as the parameter of Model.Data function,
// see Model.Data.
func (m *Model) InsertIgnore(data ...interface{}) (result sql.Result, err error) {
var ctx = m.GetCtx()
if len(data) > 0 {
return m.Data(data...).InsertIgnore()
}
return m.doInsertWithOption(InsertOptionIgnore)
return m.doInsertWithOption(ctx, InsertOptionIgnore)
}
// Replace does "REPLACE INTO ..." statement for the model.
// The optional parameter `data` is the same as the parameter of Model.Data function,
// see Model.Data.
func (m *Model) Replace(data ...interface{}) (result sql.Result, err error) {
var ctx = m.GetCtx()
if len(data) > 0 {
return m.Data(data...).Replace()
}
return m.doInsertWithOption(InsertOptionReplace)
return m.doInsertWithOption(ctx, InsertOptionReplace)
}
// Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the model.
@ -209,17 +217,18 @@ func (m *Model) Replace(data ...interface{}) (result sql.Result, err error) {
// It updates the record if there's primary or unique index in the saving data,
// or else it inserts a new record into the table.
func (m *Model) Save(data ...interface{}) (result sql.Result, err error) {
var ctx = m.GetCtx()
if len(data) > 0 {
return m.Data(data...).Save()
}
return m.doInsertWithOption(InsertOptionSave)
return m.doInsertWithOption(ctx, InsertOptionSave)
}
// doInsertWithOption inserts data with option parameter.
func (m *Model) doInsertWithOption(insertOption int) (result sql.Result, err error) {
func (m *Model) doInsertWithOption(ctx context.Context, insertOption int) (result sql.Result, err error) {
defer func() {
if err == nil {
m.checkAndRemoveCache()
m.checkAndRemoveCache(ctx)
}
}()
if m.data == nil {
@ -246,11 +255,11 @@ func (m *Model) doInsertWithOption(insertOption int) (result sql.Result, err err
case List:
list = value
for i, v := range list {
list[i] = m.db.ConvertDataForRecord(m.GetCtx(), v)
list[i] = m.db.ConvertDataForRecord(ctx, v)
}
case Map:
list = List{m.db.ConvertDataForRecord(m.GetCtx(), value)}
list = List{m.db.ConvertDataForRecord(ctx, value)}
default:
reflectInfo := reflection.OriginValueAndKind(newData)
@ -259,21 +268,21 @@ func (m *Model) doInsertWithOption(insertOption int) (result sql.Result, err err
case reflect.Slice, reflect.Array:
list = make(List, reflectInfo.OriginValue.Len())
for i := 0; i < reflectInfo.OriginValue.Len(); i++ {
list[i] = m.db.ConvertDataForRecord(m.GetCtx(), reflectInfo.OriginValue.Index(i).Interface())
list[i] = m.db.ConvertDataForRecord(ctx, reflectInfo.OriginValue.Index(i).Interface())
}
case reflect.Map:
list = List{m.db.ConvertDataForRecord(m.GetCtx(), value)}
list = List{m.db.ConvertDataForRecord(ctx, value)}
case reflect.Struct:
if v, ok := value.(iInterfaces); ok {
array := v.Interfaces()
list = make(List, len(array))
for i := 0; i < len(array); i++ {
list[i] = m.db.ConvertDataForRecord(m.GetCtx(), array[i])
list[i] = m.db.ConvertDataForRecord(ctx, array[i])
}
} else {
list = List{m.db.ConvertDataForRecord(m.GetCtx(), value)}
list = List{m.db.ConvertDataForRecord(ctx, value)}
}
default:
@ -323,7 +332,7 @@ func (m *Model) doInsertWithOption(insertOption int) (result sql.Result, err err
Data: list,
Option: doInsertOption,
}
return in.Next(m.GetCtx())
return in.Next(ctx)
}
func (m *Model) formatDoInsertOption(insertOption int, columnNames []string) (option DoInsertOption, err error) {

View File

@ -30,7 +30,8 @@ import (
// The optional parameter `where` is the same as the parameter of Model.Where function,
// see Model.Where.
func (m *Model) All(where ...interface{}) (Result, error) {
return m.doGetAll(m.GetCtx(), false, where...)
var ctx = m.GetCtx()
return m.doGetAll(ctx, false, where...)
}
// doGetAll does "SELECT FROM ..." statement for the model.
@ -44,7 +45,7 @@ func (m *Model) doGetAll(ctx context.Context, limit1 bool, where ...interface{})
if len(where) > 0 {
return m.Where(where[0], where[1:]...).All()
}
sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(m.GetCtx(), queryTypeNormal, limit1)
sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(ctx, queryTypeNormal, limit1)
return m.doGetAllBySql(ctx, queryTypeNormal, sqlWithHolder, holderArgs...)
}
@ -131,10 +132,11 @@ func (m *Model) Chunk(size int, handler ChunkHandler) {
// The optional parameter `where` is the same as the parameter of Model.Where function,
// see Model.Where.
func (m *Model) One(where ...interface{}) (Record, error) {
var ctx = m.GetCtx()
if len(where) > 0 {
return m.Where(where[0], where[1:]...).One()
}
all, err := m.doGetAll(m.GetCtx(), true)
all, err := m.doGetAll(ctx, true)
if err != nil {
return nil, err
}
@ -151,6 +153,7 @@ func (m *Model) One(where ...interface{}) (Record, error) {
// and fieldsAndWhere[1:] is treated as where condition fields.
// Also see Model.Fields and Model.Where functions.
func (m *Model) Value(fieldsAndWhere ...interface{}) (Value, error) {
var ctx = m.GetCtx()
if len(fieldsAndWhere) > 0 {
if len(fieldsAndWhere) > 2 {
return m.Fields(gconv.String(fieldsAndWhere[0])).Where(fieldsAndWhere[1], fieldsAndWhere[2:]...).Value()
@ -163,7 +166,6 @@ func (m *Model) Value(fieldsAndWhere ...interface{}) (Value, error) {
var (
all Result
err error
ctx = m.GetCtx()
)
if all, err = m.doGetAll(ctx, true); err != nil {
return nil, err
@ -373,11 +375,11 @@ func (m *Model) ScanList(structSlicePointer interface{}, bindToAttrName string,
// The optional parameter `where` is the same as the parameter of Model.Where function,
// see Model.Where.
func (m *Model) Count(where ...interface{}) (int, error) {
var ctx = m.GetCtx()
if len(where) > 0 {
return m.Where(where[0], where[1:]...).Count()
}
var (
ctx = m.GetCtx()
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, queryTypeCount, false)
all, err = m.doGetAllBySql(ctx, queryTypeCount, sqlWithHolder, holderArgs...)
)
@ -566,7 +568,7 @@ func (m *Model) doGetAllBySql(ctx context.Context, queryType int, sql string, ar
if cacheKey != "" && err == nil {
if m.cacheOption.Duration < 0 {
if _, errCache := cacheObj.Remove(ctx, cacheKey); errCache != nil {
intlog.Errorf(m.GetCtx(), `%+v`, errCache)
intlog.Errorf(ctx, `%+v`, errCache)
}
} else {
// In case of Cache Penetration.
@ -574,7 +576,7 @@ func (m *Model) doGetAllBySql(ctx context.Context, queryType int, sql string, ar
result = Result{}
}
if errCache := cacheObj.Set(ctx, cacheKey, result, m.cacheOption.Duration); errCache != nil {
intlog.Errorf(m.GetCtx(), `%+v`, errCache)
intlog.Errorf(ctx, `%+v`, errCache)
}
}
}
@ -595,7 +597,7 @@ func (m *Model) getFormattedSqlAndArgs(ctx context.Context, queryType int, limit
sqlWithHolder = fmt.Sprintf("SELECT %s FROM (%s) AS T", queryFields, m.rawSql)
return sqlWithHolder, nil
}
conditionWhere, conditionExtra, conditionArgs := m.formatCondition(false, true)
conditionWhere, conditionExtra, conditionArgs := m.formatCondition(ctx, false, true)
sqlWithHolder = fmt.Sprintf("SELECT %s FROM %s%s", queryFields, m.tables, conditionWhere+conditionExtra)
if len(m.groupBy) > 0 {
sqlWithHolder = fmt.Sprintf("SELECT COUNT(1) FROM (%s) count_alias", sqlWithHolder)
@ -603,7 +605,7 @@ func (m *Model) getFormattedSqlAndArgs(ctx context.Context, queryType int, limit
return sqlWithHolder, conditionArgs
default:
conditionWhere, conditionExtra, conditionArgs := m.formatCondition(limit1, false)
conditionWhere, conditionExtra, conditionArgs := m.formatCondition(ctx, limit1, false)
// Raw SQL Model, especially for UNION/UNION ALL featured SQL.
if m.rawSql != "" {
sqlWithHolder = fmt.Sprintf(
@ -627,7 +629,7 @@ func (m *Model) getFormattedSqlAndArgs(ctx context.Context, queryType int, limit
// Note that this function does not change any attribute value of the `m`.
//
// The parameter `limit1` specifies whether limits querying only one record if m.limit is not set.
func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWhere string, conditionExtra string, conditionArgs []interface{}) {
func (m *Model) formatCondition(ctx context.Context, limit1 bool, isCountStatement bool) (conditionWhere string, conditionExtra string, conditionArgs []interface{}) {
autoPrefix := ""
if gstr.Contains(m.tables, " JOIN ") {
autoPrefix = m.db.GetCore().QuoteWord(
@ -647,7 +649,7 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh
switch holder.Operator {
case whereHolderOperatorWhere:
if conditionWhere == "" {
newWhere, newArgs := formatWhereHolder(m.db, formatWhereHolderInput{
newWhere, newArgs := formatWhereHolder(ctx, m.db, formatWhereHolderInput{
ModelWhereHolder: holder,
OmitNil: m.option&optionOmitNilWhere > 0,
OmitEmpty: m.option&optionOmitEmptyWhere > 0,
@ -663,7 +665,7 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh
fallthrough
case whereHolderOperatorAnd:
newWhere, newArgs := formatWhereHolder(m.db, formatWhereHolderInput{
newWhere, newArgs := formatWhereHolder(ctx, m.db, formatWhereHolderInput{
ModelWhereHolder: holder,
OmitNil: m.option&optionOmitNilWhere > 0,
OmitEmpty: m.option&optionOmitEmptyWhere > 0,
@ -682,7 +684,7 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh
}
case whereHolderOperatorOr:
newWhere, newArgs := formatWhereHolder(m.db, formatWhereHolderInput{
newWhere, newArgs := formatWhereHolder(ctx, m.db, formatWhereHolderInput{
ModelWhereHolder: holder,
OmitNil: m.option&optionOmitNilWhere > 0,
OmitEmpty: m.option&optionOmitEmptyWhere > 0,
@ -733,7 +735,7 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh
Args: gconv.Interfaces(m.having[1]),
Prefix: autoPrefix,
}
havingStr, havingArgs := formatWhereHolder(m.db, formatWhereHolderInput{
havingStr, havingArgs := formatWhereHolder(ctx, m.db, formatWhereHolderInput{
ModelWhereHolder: havingHolder,
OmitNil: m.option&optionOmitNilWhere > 0,
OmitEmpty: m.option&optionOmitEmptyWhere > 0,

View File

@ -25,6 +25,7 @@ import (
// and dataAndWhere[1:] is treated as where condition fields.
// Also see Model.Data and Model.Where functions.
func (m *Model) Update(dataAndWhere ...interface{}) (result sql.Result, err error) {
var ctx = m.GetCtx()
if len(dataAndWhere) > 0 {
if len(dataAndWhere) > 2 {
return m.Data(dataAndWhere[0]).Where(dataAndWhere[1], dataAndWhere[2:]...).Update()
@ -36,7 +37,7 @@ func (m *Model) Update(dataAndWhere ...interface{}) (result sql.Result, err erro
}
defer func() {
if err == nil {
m.checkAndRemoveCache()
m.checkAndRemoveCache(ctx)
}
}()
if m.data == nil {
@ -46,11 +47,11 @@ func (m *Model) Update(dataAndWhere ...interface{}) (result sql.Result, err erro
updateData = m.data
reflectInfo = reflection.OriginTypeAndKind(updateData)
fieldNameUpdate = m.getSoftFieldNameUpdated()
conditionWhere, conditionExtra, conditionArgs = m.formatCondition(false, false)
conditionWhere, conditionExtra, conditionArgs = m.formatCondition(ctx, false, false)
)
switch reflectInfo.OriginKind {
case reflect.Map, reflect.Struct:
dataMap := m.db.ConvertDataForRecord(m.GetCtx(), m.data)
dataMap := m.db.ConvertDataForRecord(ctx, m.data)
// Automatically update the record updating time.
if !m.unscoped && fieldNameUpdate != "" {
dataMap[fieldNameUpdate] = gtime.Now().String()
@ -89,7 +90,7 @@ func (m *Model) Update(dataAndWhere ...interface{}) (result sql.Result, err erro
Condition: conditionStr,
Args: m.mergeArguments(conditionArgs),
}
return in.Next(m.GetCtx())
return in.Next(ctx)
}
// Increment increments a column's value by a given amount.