mirror of
https://github.com/gogf/gf.git
synced 2025-04-05 11:18:50 +08:00
imrove context handling for package gdb
This commit is contained in:
parent
cc01629b57
commit
09ba1bf1fb
@ -522,8 +522,11 @@ func getConfigNodeByWeight(cg ConfigGroup) *ConfigNode {
|
||||
// The parameter `master` specifies whether retrieves master node connection if
|
||||
// master-slave nodes are configured.
|
||||
func (c *Core) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err error) {
|
||||
var (
|
||||
ctx = c.db.GetCtx()
|
||||
node *ConfigNode
|
||||
)
|
||||
// Load balance.
|
||||
var node *ConfigNode
|
||||
if c.group != "" {
|
||||
node, err = getConfigNodeByGroup(c.group, master)
|
||||
if err != nil {
|
||||
@ -549,20 +552,12 @@ func (c *Core) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err error
|
||||
}
|
||||
// Cache the underlying connection pool object by node.
|
||||
v := c.links.GetOrSetFuncLock(node.String(), func() interface{} {
|
||||
intlog.Printf(
|
||||
c.db.GetCtx(),
|
||||
`open new connection, master:%#v, config:%#v, node:%#v`,
|
||||
master, c.config, node,
|
||||
)
|
||||
intlog.Printf(ctx, `open new connection, master:%#v, config:%#v, node:%#v`, master, c.config, node)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
intlog.Printf(c.db.GetCtx(), `open new connection failed: %v, %#v`, err, node)
|
||||
intlog.Printf(ctx, `open new connection failed: %v, %#v`, err, node)
|
||||
} else {
|
||||
intlog.Printf(
|
||||
c.db.GetCtx(),
|
||||
`open new connection success, master:%#v, config:%#v, node:%#v`,
|
||||
master, c.config, node,
|
||||
)
|
||||
intlog.Printf(ctx, `open new connection success, master:%#v, config:%#v, node:%#v`, master, c.config, node)
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -67,9 +67,9 @@ func (c *Core) GetCtx() context.Context {
|
||||
}
|
||||
|
||||
// GetCtxTimeout returns the context and cancel function for specified timeout type.
|
||||
func (c *Core) GetCtxTimeout(timeoutType int, ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
func (c *Core) GetCtxTimeout(ctx context.Context, timeoutType int) (context.Context, context.CancelFunc) {
|
||||
if ctx == nil {
|
||||
ctx = c.GetCtx()
|
||||
ctx = c.db.GetCtx()
|
||||
} else {
|
||||
ctx = context.WithValue(ctx, "WrappedByGetCtxTimeout", nil)
|
||||
}
|
||||
@ -255,17 +255,18 @@ func (c *Core) GetCount(ctx context.Context, sql string, args ...interface{}) (i
|
||||
|
||||
// Union does "(SELECT xxx FROM xxx) UNION (SELECT xxx FROM xxx) ..." statement.
|
||||
func (c *Core) Union(unions ...*Model) *Model {
|
||||
return c.doUnion(unionTypeNormal, unions...)
|
||||
var ctx = c.db.GetCtx()
|
||||
return c.doUnion(ctx, unionTypeNormal, unions...)
|
||||
}
|
||||
|
||||
// UnionAll does "(SELECT xxx FROM xxx) UNION ALL (SELECT xxx FROM xxx) ..." statement.
|
||||
func (c *Core) UnionAll(unions ...*Model) *Model {
|
||||
return c.doUnion(unionTypeAll, unions...)
|
||||
var ctx = c.db.GetCtx()
|
||||
return c.doUnion(ctx, unionTypeAll, unions...)
|
||||
}
|
||||
|
||||
func (c *Core) doUnion(unionType int, unions ...*Model) *Model {
|
||||
func (c *Core) doUnion(ctx context.Context, unionType int, unions ...*Model) *Model {
|
||||
var (
|
||||
ctx = c.db.GetCtx()
|
||||
unionTypeStr string
|
||||
composedSqlStr string
|
||||
composedArgs = make([]interface{}, 0)
|
||||
@ -289,10 +290,11 @@ func (c *Core) doUnion(unionType int, unions ...*Model) *Model {
|
||||
|
||||
// PingMaster pings the master node to check authentication or keeps the connection alive.
|
||||
func (c *Core) PingMaster() error {
|
||||
var ctx = c.db.GetCtx()
|
||||
if master, err := c.db.Master(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
if err = master.PingContext(c.GetCtx()); err != nil {
|
||||
if err = master.PingContext(ctx); err != nil {
|
||||
err = gerror.WrapCode(gcode.CodeDbOperationError, err, `master.Ping failed`)
|
||||
}
|
||||
return err
|
||||
@ -301,10 +303,11 @@ func (c *Core) PingMaster() error {
|
||||
|
||||
// PingSlave pings the slave node to check authentication or keeps the connection alive.
|
||||
func (c *Core) PingSlave() error {
|
||||
var ctx = c.db.GetCtx()
|
||||
if slave, err := c.db.Slave(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
if err = slave.PingContext(c.GetCtx()); err != nil {
|
||||
if err = slave.PingContext(ctx); err != nil {
|
||||
err = gerror.WrapCode(gcode.CodeDbOperationError, err, `slave.Ping failed`)
|
||||
}
|
||||
return err
|
||||
@ -663,21 +666,22 @@ func (c *Core) writeSqlToLogger(ctx context.Context, sql *Sql) {
|
||||
|
||||
// HasTable determine whether the table name exists in the database.
|
||||
func (c *Core) HasTable(name string) (bool, error) {
|
||||
result, err := c.GetCache().GetOrSetFuncLock(
|
||||
c.GetCtx(),
|
||||
fmt.Sprintf(`HasTable: %s`, name),
|
||||
func(ctx context.Context) (interface{}, error) {
|
||||
tableList, err := c.db.Tables(ctx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
var (
|
||||
ctx = c.db.GetCtx()
|
||||
cacheKey = fmt.Sprintf(`HasTable: %s`, name)
|
||||
)
|
||||
result, err := c.GetCache().GetOrSetFuncLock(ctx, cacheKey, func(ctx context.Context) (interface{}, error) {
|
||||
tableList, err := c.db.Tables(ctx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, table := range tableList {
|
||||
if table == name {
|
||||
return true, nil
|
||||
}
|
||||
for _, table := range tableList {
|
||||
if table == name {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}, 0,
|
||||
}
|
||||
return false, nil
|
||||
}, 0,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
@ -71,7 +71,7 @@ func (c *Core) doBeginCtx(ctx context.Context) (*TX, error) {
|
||||
func (c *Core) Transaction(ctx context.Context, f func(ctx context.Context, tx *TX) error) (err error) {
|
||||
var tx *TX
|
||||
if ctx == nil {
|
||||
ctx = c.GetCtx()
|
||||
ctx = c.db.GetCtx()
|
||||
}
|
||||
// Check transaction object from context.
|
||||
tx = TXFromCtx(ctx, c.db.GetGroup())
|
||||
|
@ -234,7 +234,7 @@ func (c *Core) DoCommit(ctx context.Context, in DoCommitInput) (out DoCommitOutp
|
||||
out.RawResult = sqlStmt
|
||||
|
||||
case SqlTypeStmtExecContext:
|
||||
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctxTimeoutTypeExec, ctx)
|
||||
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeExec)
|
||||
defer cancelFuncForTimeout()
|
||||
if c.db.GetDryRun() {
|
||||
sqlResult = new(SqlResult)
|
||||
@ -244,13 +244,13 @@ func (c *Core) DoCommit(ctx context.Context, in DoCommitInput) (out DoCommitOutp
|
||||
out.RawResult = sqlResult
|
||||
|
||||
case SqlTypeStmtQueryContext:
|
||||
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctxTimeoutTypeQuery, ctx)
|
||||
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeQuery)
|
||||
defer cancelFuncForTimeout()
|
||||
stmtSqlRows, err = in.Stmt.QueryContext(ctx, in.Args...)
|
||||
out.RawResult = stmtSqlRows
|
||||
|
||||
case SqlTypeStmtQueryRowContext:
|
||||
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctxTimeoutTypeQuery, ctx)
|
||||
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeQuery)
|
||||
defer cancelFuncForTimeout()
|
||||
stmtSqlRow = in.Stmt.QueryRowContext(ctx, in.Args...)
|
||||
out.RawResult = stmtSqlRow
|
||||
|
@ -150,11 +150,12 @@ func (c *Core) Tables(schema ...string) (tables []string, err error) {
|
||||
//
|
||||
// It does nothing in default.
|
||||
func (c *Core) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
|
||||
var ctx = c.db.GetCtx()
|
||||
// It does nothing if given table is empty, especially in sub-query.
|
||||
if table == "" {
|
||||
return map[string]*TableField{}, nil
|
||||
}
|
||||
return c.db.TableFields(c.GetCtx(), table, schema...)
|
||||
return c.db.TableFields(ctx, table, schema...)
|
||||
}
|
||||
|
||||
// HasField determine whether the field exists in the table.
|
||||
|
@ -38,6 +38,7 @@ func (d *DriverMysql) New(core *Core, node *ConfigNode) (DB, error) {
|
||||
// Note that it converts time.Time argument to local timezone in default.
|
||||
func (d *DriverMysql) Open(config *ConfigNode) (db *sql.DB, err error) {
|
||||
var (
|
||||
ctx = d.GetCtx()
|
||||
source string
|
||||
underlyingDriverName = "mysql"
|
||||
)
|
||||
@ -56,7 +57,7 @@ func (d *DriverMysql) Open(config *ConfigNode) (db *sql.DB, err error) {
|
||||
source = fmt.Sprintf("%s&loc=%s", source, url.QueryEscape(config.Timezone))
|
||||
}
|
||||
}
|
||||
intlog.Printf(d.GetCtx(), "Open: %s", source)
|
||||
intlog.Printf(ctx, "Open: %s", source)
|
||||
if db, err = sql.Open(underlyingDriverName, source); err != nil {
|
||||
err = gerror.WrapCodef(
|
||||
gcode.CodeDbOperationError, err,
|
||||
|
@ -8,6 +8,7 @@ package gdb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
@ -364,7 +365,7 @@ func isKeyValueCanBeOmitEmpty(omitEmpty bool, whereType string, key, value inter
|
||||
}
|
||||
|
||||
// formatWhereHolder formats where statement and its arguments for `Where` and `Having` statements.
|
||||
func formatWhereHolder(db DB, in formatWhereHolderInput) (newWhere string, newArgs []interface{}) {
|
||||
func formatWhereHolder(ctx context.Context, db DB, in formatWhereHolderInput) (newWhere string, newArgs []interface{}) {
|
||||
var (
|
||||
buffer = bytes.NewBuffer(nil)
|
||||
reflectInfo = reflection.OriginValueAndKind(in.Where)
|
||||
@ -393,7 +394,7 @@ func formatWhereHolder(db DB, in formatWhereHolderInput) (newWhere string, newAr
|
||||
}
|
||||
|
||||
case reflect.Struct:
|
||||
// If the `where` parameter is DO struct, it then adds `OmitNil` option for this condition,
|
||||
// If the `where` parameter is `DO` struct, it then adds `OmitNil` option for this condition,
|
||||
// which will filter all nil parameters in `where`.
|
||||
if isDoStruct(in.Where) {
|
||||
in.OmitNil = true
|
||||
@ -523,7 +524,9 @@ func formatWhereHolder(db DB, in formatWhereHolderInput) (newWhere string, newAr
|
||||
whereStr, _ = gregex.ReplaceStringFunc(`(\?)`, whereStr, func(s string) string {
|
||||
index++
|
||||
if i+len(newArgs) == index {
|
||||
sqlWithHolder, holderArgs := model.getFormattedSqlAndArgs(model.GetCtx(), queryTypeNormal, false)
|
||||
sqlWithHolder, holderArgs := model.getFormattedSqlAndArgs(
|
||||
ctx, queryTypeNormal, false,
|
||||
)
|
||||
newArgs = append(newArgs, holderArgs...)
|
||||
// Automatically adding the brackets.
|
||||
return "(" + sqlWithHolder + ")"
|
||||
|
@ -93,6 +93,7 @@ const (
|
||||
// db.Model("? AS a, ? AS b", subQuery1, subQuery2)
|
||||
func (c *Core) Model(tableNameQueryOrStruct ...interface{}) *Model {
|
||||
var (
|
||||
ctx = c.db.GetCtx()
|
||||
tableStr string
|
||||
tableName string
|
||||
extraArgs []interface{}
|
||||
@ -105,7 +106,7 @@ func (c *Core) Model(tableNameQueryOrStruct ...interface{}) *Model {
|
||||
Where: conditionStr,
|
||||
Args: tableNameQueryOrStruct[1:],
|
||||
}
|
||||
tableStr, extraArgs = formatWhereHolder(c.db, formatWhereHolderInput{
|
||||
tableStr, extraArgs = formatWhereHolder(ctx, c.db, formatWhereHolderInput{
|
||||
ModelWhereHolder: whereHolder,
|
||||
OmitNil: false,
|
||||
OmitEmpty: false,
|
||||
|
@ -7,6 +7,7 @@
|
||||
package gdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/gogf/gf/v2/internal/intlog"
|
||||
@ -44,11 +45,9 @@ func (m *Model) Cache(option CacheOption) *Model {
|
||||
|
||||
// checkAndRemoveCache checks and removes the cache in insert/update/delete statement if
|
||||
// cache feature is enabled.
|
||||
func (m *Model) checkAndRemoveCache() {
|
||||
func (m *Model) checkAndRemoveCache(ctx context.Context) {
|
||||
if m.cacheEnabled && m.cacheOption.Duration < 0 && len(m.cacheOption.Name) > 0 {
|
||||
ctx := m.GetCtx()
|
||||
_, err := m.db.GetCache().Remove(ctx, m.cacheOption.Name)
|
||||
if err != nil {
|
||||
if _, err := m.db.GetCache().Remove(ctx, m.cacheOption.Name); err != nil {
|
||||
intlog.Errorf(ctx, `%+v`, err)
|
||||
}
|
||||
}
|
||||
|
@ -20,17 +20,18 @@ import (
|
||||
// The optional parameter `where` is the same as the parameter of Model.Where function,
|
||||
// see Model.Where.
|
||||
func (m *Model) Delete(where ...interface{}) (result sql.Result, err error) {
|
||||
var ctx = m.GetCtx()
|
||||
if len(where) > 0 {
|
||||
return m.Where(where[0], where[1:]...).Delete()
|
||||
}
|
||||
defer func() {
|
||||
if err == nil {
|
||||
m.checkAndRemoveCache()
|
||||
m.checkAndRemoveCache(ctx)
|
||||
}
|
||||
}()
|
||||
var (
|
||||
fieldNameDelete = m.getSoftFieldNameDeleted()
|
||||
conditionWhere, conditionExtra, conditionArgs = m.formatCondition(false, false)
|
||||
conditionWhere, conditionExtra, conditionArgs = m.formatCondition(ctx, false, false)
|
||||
)
|
||||
// Soft deleting.
|
||||
if !m.unscoped && fieldNameDelete != "" {
|
||||
@ -47,7 +48,7 @@ func (m *Model) Delete(where ...interface{}) (result sql.Result, err error) {
|
||||
Condition: conditionWhere + conditionExtra,
|
||||
Args: append([]interface{}{gtime.Now().String()}, conditionArgs...),
|
||||
}
|
||||
return in.Next(m.GetCtx())
|
||||
return in.Next(ctx)
|
||||
}
|
||||
conditionStr := conditionWhere + conditionExtra
|
||||
if !gstr.ContainsI(conditionStr, " WHERE ") {
|
||||
@ -69,5 +70,5 @@ func (m *Model) Delete(where ...interface{}) (result sql.Result, err error) {
|
||||
Condition: conditionStr,
|
||||
Args: conditionArgs,
|
||||
}
|
||||
return in.Next(m.GetCtx())
|
||||
return in.Next(ctx)
|
||||
}
|
||||
|
@ -7,6 +7,7 @@
|
||||
package gdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"reflect"
|
||||
|
||||
@ -38,7 +39,10 @@ func (m *Model) Batch(batch int) *Model {
|
||||
// Data(g.Map{"uid": 10000, "name":"john"})
|
||||
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"}).
|
||||
func (m *Model) Data(data ...interface{}) *Model {
|
||||
model := m.getModel()
|
||||
var (
|
||||
ctx = m.GetCtx()
|
||||
model = m.getModel()
|
||||
)
|
||||
if len(data) > 1 {
|
||||
if s := gconv.String(data[0]); gstr.Contains(s, "?") {
|
||||
model.data = s
|
||||
@ -83,7 +87,7 @@ func (m *Model) Data(data ...interface{}) *Model {
|
||||
}
|
||||
list := make(List, reflectInfo.OriginValue.Len())
|
||||
for i := 0; i < reflectInfo.OriginValue.Len(); i++ {
|
||||
list[i] = m.db.ConvertDataForRecord(m.GetCtx(), reflectInfo.OriginValue.Index(i).Interface())
|
||||
list[i] = m.db.ConvertDataForRecord(ctx, reflectInfo.OriginValue.Index(i).Interface())
|
||||
}
|
||||
model.data = list
|
||||
|
||||
@ -100,15 +104,15 @@ func (m *Model) Data(data ...interface{}) *Model {
|
||||
list = make(List, len(array))
|
||||
)
|
||||
for i := 0; i < len(array); i++ {
|
||||
list[i] = m.db.ConvertDataForRecord(m.GetCtx(), array[i])
|
||||
list[i] = m.db.ConvertDataForRecord(ctx, array[i])
|
||||
}
|
||||
model.data = list
|
||||
} else {
|
||||
model.data = m.db.ConvertDataForRecord(m.GetCtx(), data[0])
|
||||
model.data = m.db.ConvertDataForRecord(ctx, data[0])
|
||||
}
|
||||
|
||||
case reflect.Map:
|
||||
model.data = m.db.ConvertDataForRecord(m.GetCtx(), data[0])
|
||||
model.data = m.db.ConvertDataForRecord(ctx, data[0])
|
||||
|
||||
default:
|
||||
model.data = data[0]
|
||||
@ -164,18 +168,20 @@ func (m *Model) OnDuplicateEx(onDuplicateEx ...interface{}) *Model {
|
||||
// The optional parameter `data` is the same as the parameter of Model.Data function,
|
||||
// see Model.Data.
|
||||
func (m *Model) Insert(data ...interface{}) (result sql.Result, err error) {
|
||||
var ctx = m.GetCtx()
|
||||
if len(data) > 0 {
|
||||
return m.Data(data...).Insert()
|
||||
}
|
||||
return m.doInsertWithOption(InsertOptionDefault)
|
||||
return m.doInsertWithOption(ctx, InsertOptionDefault)
|
||||
}
|
||||
|
||||
// InsertAndGetId performs action Insert and returns the last insert id that automatically generated.
|
||||
func (m *Model) InsertAndGetId(data ...interface{}) (lastInsertId int64, err error) {
|
||||
var ctx = m.GetCtx()
|
||||
if len(data) > 0 {
|
||||
return m.Data(data...).InsertAndGetId()
|
||||
}
|
||||
result, err := m.doInsertWithOption(InsertOptionDefault)
|
||||
result, err := m.doInsertWithOption(ctx, InsertOptionDefault)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -186,20 +192,22 @@ func (m *Model) InsertAndGetId(data ...interface{}) (lastInsertId int64, err err
|
||||
// The optional parameter `data` is the same as the parameter of Model.Data function,
|
||||
// see Model.Data.
|
||||
func (m *Model) InsertIgnore(data ...interface{}) (result sql.Result, err error) {
|
||||
var ctx = m.GetCtx()
|
||||
if len(data) > 0 {
|
||||
return m.Data(data...).InsertIgnore()
|
||||
}
|
||||
return m.doInsertWithOption(InsertOptionIgnore)
|
||||
return m.doInsertWithOption(ctx, InsertOptionIgnore)
|
||||
}
|
||||
|
||||
// Replace does "REPLACE INTO ..." statement for the model.
|
||||
// The optional parameter `data` is the same as the parameter of Model.Data function,
|
||||
// see Model.Data.
|
||||
func (m *Model) Replace(data ...interface{}) (result sql.Result, err error) {
|
||||
var ctx = m.GetCtx()
|
||||
if len(data) > 0 {
|
||||
return m.Data(data...).Replace()
|
||||
}
|
||||
return m.doInsertWithOption(InsertOptionReplace)
|
||||
return m.doInsertWithOption(ctx, InsertOptionReplace)
|
||||
}
|
||||
|
||||
// Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the model.
|
||||
@ -209,17 +217,18 @@ func (m *Model) Replace(data ...interface{}) (result sql.Result, err error) {
|
||||
// It updates the record if there's primary or unique index in the saving data,
|
||||
// or else it inserts a new record into the table.
|
||||
func (m *Model) Save(data ...interface{}) (result sql.Result, err error) {
|
||||
var ctx = m.GetCtx()
|
||||
if len(data) > 0 {
|
||||
return m.Data(data...).Save()
|
||||
}
|
||||
return m.doInsertWithOption(InsertOptionSave)
|
||||
return m.doInsertWithOption(ctx, InsertOptionSave)
|
||||
}
|
||||
|
||||
// doInsertWithOption inserts data with option parameter.
|
||||
func (m *Model) doInsertWithOption(insertOption int) (result sql.Result, err error) {
|
||||
func (m *Model) doInsertWithOption(ctx context.Context, insertOption int) (result sql.Result, err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
m.checkAndRemoveCache()
|
||||
m.checkAndRemoveCache(ctx)
|
||||
}
|
||||
}()
|
||||
if m.data == nil {
|
||||
@ -246,11 +255,11 @@ func (m *Model) doInsertWithOption(insertOption int) (result sql.Result, err err
|
||||
case List:
|
||||
list = value
|
||||
for i, v := range list {
|
||||
list[i] = m.db.ConvertDataForRecord(m.GetCtx(), v)
|
||||
list[i] = m.db.ConvertDataForRecord(ctx, v)
|
||||
}
|
||||
|
||||
case Map:
|
||||
list = List{m.db.ConvertDataForRecord(m.GetCtx(), value)}
|
||||
list = List{m.db.ConvertDataForRecord(ctx, value)}
|
||||
|
||||
default:
|
||||
reflectInfo := reflection.OriginValueAndKind(newData)
|
||||
@ -259,21 +268,21 @@ func (m *Model) doInsertWithOption(insertOption int) (result sql.Result, err err
|
||||
case reflect.Slice, reflect.Array:
|
||||
list = make(List, reflectInfo.OriginValue.Len())
|
||||
for i := 0; i < reflectInfo.OriginValue.Len(); i++ {
|
||||
list[i] = m.db.ConvertDataForRecord(m.GetCtx(), reflectInfo.OriginValue.Index(i).Interface())
|
||||
list[i] = m.db.ConvertDataForRecord(ctx, reflectInfo.OriginValue.Index(i).Interface())
|
||||
}
|
||||
|
||||
case reflect.Map:
|
||||
list = List{m.db.ConvertDataForRecord(m.GetCtx(), value)}
|
||||
list = List{m.db.ConvertDataForRecord(ctx, value)}
|
||||
|
||||
case reflect.Struct:
|
||||
if v, ok := value.(iInterfaces); ok {
|
||||
array := v.Interfaces()
|
||||
list = make(List, len(array))
|
||||
for i := 0; i < len(array); i++ {
|
||||
list[i] = m.db.ConvertDataForRecord(m.GetCtx(), array[i])
|
||||
list[i] = m.db.ConvertDataForRecord(ctx, array[i])
|
||||
}
|
||||
} else {
|
||||
list = List{m.db.ConvertDataForRecord(m.GetCtx(), value)}
|
||||
list = List{m.db.ConvertDataForRecord(ctx, value)}
|
||||
}
|
||||
|
||||
default:
|
||||
@ -323,7 +332,7 @@ func (m *Model) doInsertWithOption(insertOption int) (result sql.Result, err err
|
||||
Data: list,
|
||||
Option: doInsertOption,
|
||||
}
|
||||
return in.Next(m.GetCtx())
|
||||
return in.Next(ctx)
|
||||
}
|
||||
|
||||
func (m *Model) formatDoInsertOption(insertOption int, columnNames []string) (option DoInsertOption, err error) {
|
||||
|
@ -30,7 +30,8 @@ import (
|
||||
// The optional parameter `where` is the same as the parameter of Model.Where function,
|
||||
// see Model.Where.
|
||||
func (m *Model) All(where ...interface{}) (Result, error) {
|
||||
return m.doGetAll(m.GetCtx(), false, where...)
|
||||
var ctx = m.GetCtx()
|
||||
return m.doGetAll(ctx, false, where...)
|
||||
}
|
||||
|
||||
// doGetAll does "SELECT FROM ..." statement for the model.
|
||||
@ -44,7 +45,7 @@ func (m *Model) doGetAll(ctx context.Context, limit1 bool, where ...interface{})
|
||||
if len(where) > 0 {
|
||||
return m.Where(where[0], where[1:]...).All()
|
||||
}
|
||||
sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(m.GetCtx(), queryTypeNormal, limit1)
|
||||
sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(ctx, queryTypeNormal, limit1)
|
||||
return m.doGetAllBySql(ctx, queryTypeNormal, sqlWithHolder, holderArgs...)
|
||||
}
|
||||
|
||||
@ -131,10 +132,11 @@ func (m *Model) Chunk(size int, handler ChunkHandler) {
|
||||
// The optional parameter `where` is the same as the parameter of Model.Where function,
|
||||
// see Model.Where.
|
||||
func (m *Model) One(where ...interface{}) (Record, error) {
|
||||
var ctx = m.GetCtx()
|
||||
if len(where) > 0 {
|
||||
return m.Where(where[0], where[1:]...).One()
|
||||
}
|
||||
all, err := m.doGetAll(m.GetCtx(), true)
|
||||
all, err := m.doGetAll(ctx, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -151,6 +153,7 @@ func (m *Model) One(where ...interface{}) (Record, error) {
|
||||
// and fieldsAndWhere[1:] is treated as where condition fields.
|
||||
// Also see Model.Fields and Model.Where functions.
|
||||
func (m *Model) Value(fieldsAndWhere ...interface{}) (Value, error) {
|
||||
var ctx = m.GetCtx()
|
||||
if len(fieldsAndWhere) > 0 {
|
||||
if len(fieldsAndWhere) > 2 {
|
||||
return m.Fields(gconv.String(fieldsAndWhere[0])).Where(fieldsAndWhere[1], fieldsAndWhere[2:]...).Value()
|
||||
@ -163,7 +166,6 @@ func (m *Model) Value(fieldsAndWhere ...interface{}) (Value, error) {
|
||||
var (
|
||||
all Result
|
||||
err error
|
||||
ctx = m.GetCtx()
|
||||
)
|
||||
if all, err = m.doGetAll(ctx, true); err != nil {
|
||||
return nil, err
|
||||
@ -373,11 +375,11 @@ func (m *Model) ScanList(structSlicePointer interface{}, bindToAttrName string,
|
||||
// The optional parameter `where` is the same as the parameter of Model.Where function,
|
||||
// see Model.Where.
|
||||
func (m *Model) Count(where ...interface{}) (int, error) {
|
||||
var ctx = m.GetCtx()
|
||||
if len(where) > 0 {
|
||||
return m.Where(where[0], where[1:]...).Count()
|
||||
}
|
||||
var (
|
||||
ctx = m.GetCtx()
|
||||
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, queryTypeCount, false)
|
||||
all, err = m.doGetAllBySql(ctx, queryTypeCount, sqlWithHolder, holderArgs...)
|
||||
)
|
||||
@ -566,7 +568,7 @@ func (m *Model) doGetAllBySql(ctx context.Context, queryType int, sql string, ar
|
||||
if cacheKey != "" && err == nil {
|
||||
if m.cacheOption.Duration < 0 {
|
||||
if _, errCache := cacheObj.Remove(ctx, cacheKey); errCache != nil {
|
||||
intlog.Errorf(m.GetCtx(), `%+v`, errCache)
|
||||
intlog.Errorf(ctx, `%+v`, errCache)
|
||||
}
|
||||
} else {
|
||||
// In case of Cache Penetration.
|
||||
@ -574,7 +576,7 @@ func (m *Model) doGetAllBySql(ctx context.Context, queryType int, sql string, ar
|
||||
result = Result{}
|
||||
}
|
||||
if errCache := cacheObj.Set(ctx, cacheKey, result, m.cacheOption.Duration); errCache != nil {
|
||||
intlog.Errorf(m.GetCtx(), `%+v`, errCache)
|
||||
intlog.Errorf(ctx, `%+v`, errCache)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -595,7 +597,7 @@ func (m *Model) getFormattedSqlAndArgs(ctx context.Context, queryType int, limit
|
||||
sqlWithHolder = fmt.Sprintf("SELECT %s FROM (%s) AS T", queryFields, m.rawSql)
|
||||
return sqlWithHolder, nil
|
||||
}
|
||||
conditionWhere, conditionExtra, conditionArgs := m.formatCondition(false, true)
|
||||
conditionWhere, conditionExtra, conditionArgs := m.formatCondition(ctx, false, true)
|
||||
sqlWithHolder = fmt.Sprintf("SELECT %s FROM %s%s", queryFields, m.tables, conditionWhere+conditionExtra)
|
||||
if len(m.groupBy) > 0 {
|
||||
sqlWithHolder = fmt.Sprintf("SELECT COUNT(1) FROM (%s) count_alias", sqlWithHolder)
|
||||
@ -603,7 +605,7 @@ func (m *Model) getFormattedSqlAndArgs(ctx context.Context, queryType int, limit
|
||||
return sqlWithHolder, conditionArgs
|
||||
|
||||
default:
|
||||
conditionWhere, conditionExtra, conditionArgs := m.formatCondition(limit1, false)
|
||||
conditionWhere, conditionExtra, conditionArgs := m.formatCondition(ctx, limit1, false)
|
||||
// Raw SQL Model, especially for UNION/UNION ALL featured SQL.
|
||||
if m.rawSql != "" {
|
||||
sqlWithHolder = fmt.Sprintf(
|
||||
@ -627,7 +629,7 @@ func (m *Model) getFormattedSqlAndArgs(ctx context.Context, queryType int, limit
|
||||
// Note that this function does not change any attribute value of the `m`.
|
||||
//
|
||||
// The parameter `limit1` specifies whether limits querying only one record if m.limit is not set.
|
||||
func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWhere string, conditionExtra string, conditionArgs []interface{}) {
|
||||
func (m *Model) formatCondition(ctx context.Context, limit1 bool, isCountStatement bool) (conditionWhere string, conditionExtra string, conditionArgs []interface{}) {
|
||||
autoPrefix := ""
|
||||
if gstr.Contains(m.tables, " JOIN ") {
|
||||
autoPrefix = m.db.GetCore().QuoteWord(
|
||||
@ -647,7 +649,7 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh
|
||||
switch holder.Operator {
|
||||
case whereHolderOperatorWhere:
|
||||
if conditionWhere == "" {
|
||||
newWhere, newArgs := formatWhereHolder(m.db, formatWhereHolderInput{
|
||||
newWhere, newArgs := formatWhereHolder(ctx, m.db, formatWhereHolderInput{
|
||||
ModelWhereHolder: holder,
|
||||
OmitNil: m.option&optionOmitNilWhere > 0,
|
||||
OmitEmpty: m.option&optionOmitEmptyWhere > 0,
|
||||
@ -663,7 +665,7 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh
|
||||
fallthrough
|
||||
|
||||
case whereHolderOperatorAnd:
|
||||
newWhere, newArgs := formatWhereHolder(m.db, formatWhereHolderInput{
|
||||
newWhere, newArgs := formatWhereHolder(ctx, m.db, formatWhereHolderInput{
|
||||
ModelWhereHolder: holder,
|
||||
OmitNil: m.option&optionOmitNilWhere > 0,
|
||||
OmitEmpty: m.option&optionOmitEmptyWhere > 0,
|
||||
@ -682,7 +684,7 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh
|
||||
}
|
||||
|
||||
case whereHolderOperatorOr:
|
||||
newWhere, newArgs := formatWhereHolder(m.db, formatWhereHolderInput{
|
||||
newWhere, newArgs := formatWhereHolder(ctx, m.db, formatWhereHolderInput{
|
||||
ModelWhereHolder: holder,
|
||||
OmitNil: m.option&optionOmitNilWhere > 0,
|
||||
OmitEmpty: m.option&optionOmitEmptyWhere > 0,
|
||||
@ -733,7 +735,7 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh
|
||||
Args: gconv.Interfaces(m.having[1]),
|
||||
Prefix: autoPrefix,
|
||||
}
|
||||
havingStr, havingArgs := formatWhereHolder(m.db, formatWhereHolderInput{
|
||||
havingStr, havingArgs := formatWhereHolder(ctx, m.db, formatWhereHolderInput{
|
||||
ModelWhereHolder: havingHolder,
|
||||
OmitNil: m.option&optionOmitNilWhere > 0,
|
||||
OmitEmpty: m.option&optionOmitEmptyWhere > 0,
|
||||
|
@ -25,6 +25,7 @@ import (
|
||||
// and dataAndWhere[1:] is treated as where condition fields.
|
||||
// Also see Model.Data and Model.Where functions.
|
||||
func (m *Model) Update(dataAndWhere ...interface{}) (result sql.Result, err error) {
|
||||
var ctx = m.GetCtx()
|
||||
if len(dataAndWhere) > 0 {
|
||||
if len(dataAndWhere) > 2 {
|
||||
return m.Data(dataAndWhere[0]).Where(dataAndWhere[1], dataAndWhere[2:]...).Update()
|
||||
@ -36,7 +37,7 @@ func (m *Model) Update(dataAndWhere ...interface{}) (result sql.Result, err erro
|
||||
}
|
||||
defer func() {
|
||||
if err == nil {
|
||||
m.checkAndRemoveCache()
|
||||
m.checkAndRemoveCache(ctx)
|
||||
}
|
||||
}()
|
||||
if m.data == nil {
|
||||
@ -46,11 +47,11 @@ func (m *Model) Update(dataAndWhere ...interface{}) (result sql.Result, err erro
|
||||
updateData = m.data
|
||||
reflectInfo = reflection.OriginTypeAndKind(updateData)
|
||||
fieldNameUpdate = m.getSoftFieldNameUpdated()
|
||||
conditionWhere, conditionExtra, conditionArgs = m.formatCondition(false, false)
|
||||
conditionWhere, conditionExtra, conditionArgs = m.formatCondition(ctx, false, false)
|
||||
)
|
||||
switch reflectInfo.OriginKind {
|
||||
case reflect.Map, reflect.Struct:
|
||||
dataMap := m.db.ConvertDataForRecord(m.GetCtx(), m.data)
|
||||
dataMap := m.db.ConvertDataForRecord(ctx, m.data)
|
||||
// Automatically update the record updating time.
|
||||
if !m.unscoped && fieldNameUpdate != "" {
|
||||
dataMap[fieldNameUpdate] = gtime.Now().String()
|
||||
@ -89,7 +90,7 @@ func (m *Model) Update(dataAndWhere ...interface{}) (result sql.Result, err erro
|
||||
Condition: conditionStr,
|
||||
Args: m.mergeArguments(conditionArgs),
|
||||
}
|
||||
return in.Next(m.GetCtx())
|
||||
return in.Next(ctx)
|
||||
}
|
||||
|
||||
// Increment increments a column's value by a given amount.
|
||||
|
Loading…
x
Reference in New Issue
Block a user