1
0
mirror of https://github.com/gogf/gf.git synced 2025-04-05 03:05:05 +08:00

fix(database/gdb): gdb.Counter not work in OnDuplicate (#4073)

This commit is contained in:
CyJaySong 2024-12-26 18:18:35 +08:00 committed by GitHub
parent 9ce2409659
commit 80f57d1c24
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 264 additions and 46 deletions

View File

@ -2812,6 +2812,28 @@ func Test_Model_OnDuplicate(t *testing.T) {
}) })
} }
func Test_Model_OnDuplicateWithCounter(t *testing.T) {
table := createInitTable()
defer dropTable(table)
gtest.C(t, func(t *gtest.T) {
data := g.Map{
"id": 1,
"passport": "pp1",
"password": "pw1",
"nickname": "n1",
"create_time": "2016-06-06",
}
_, err := db.Model(table).OnConflict("id").OnDuplicate(g.Map{
"id": gdb.Counter{Field: "id", Value: 999999},
}).Data(data).Save()
t.AssertNil(err)
one, err := db.Model(table).WherePri(1).One()
t.AssertNil(err)
t.AssertNil(one)
})
}
func Test_Model_OnDuplicateEx(t *testing.T) { func Test_Model_OnDuplicateEx(t *testing.T) {
table := createInitTable() table := createInitTable()
defer dropTable(table) defer dropTable(table)

View File

@ -40,6 +40,25 @@ func (d *Driver) FormatUpsert(columns []string, list gdb.List, option gdb.DoInse
d.Core.QuoteWord(k), d.Core.QuoteWord(k),
v, v,
) )
case gdb.Counter, *gdb.Counter:
var counter gdb.Counter
switch value := v.(type) {
case gdb.Counter:
counter = value
case *gdb.Counter:
counter = *value
}
operator, columnVal := "+", counter.Value
if columnVal < 0 {
operator, columnVal = "-", -columnVal
}
onDuplicateStr += fmt.Sprintf(
"%s=EXCLUDED.%s%s%s",
d.QuoteWord(k),
d.QuoteWord(counter.Field),
operator,
gconv.String(columnVal),
)
default: default:
onDuplicateStr += fmt.Sprintf( onDuplicateStr += fmt.Sprintf(
"%s=EXCLUDED.%s", "%s=EXCLUDED.%s",

View File

@ -521,6 +521,28 @@ func Test_Model_OnDuplicate(t *testing.T) {
}) })
} }
func Test_Model_OnDuplicateWithCounter(t *testing.T) {
table := createInitTable()
defer dropTable(table)
gtest.C(t, func(t *gtest.T) {
data := g.Map{
"id": 1,
"passport": "pp1",
"password": "pw1",
"nickname": "n1",
"create_time": "2016-06-06",
}
_, err := db.Model(table).OnConflict("id").OnDuplicate(g.Map{
"id": gdb.Counter{Field: "id", Value: 999999},
}).Data(data).Save()
t.AssertNil(err)
one, err := db.Model(table).WherePri(1).One()
t.AssertNil(err)
t.AssertNil(one)
})
}
func Test_Model_OnDuplicateEx(t *testing.T) { func Test_Model_OnDuplicateEx(t *testing.T) {
table := createInitTable() table := createInitTable()
defer dropTable(table) defer dropTable(table)

View File

@ -40,6 +40,25 @@ func (d *Driver) FormatUpsert(columns []string, list gdb.List, option gdb.DoInse
d.Core.QuoteWord(k), d.Core.QuoteWord(k),
v, v,
) )
case gdb.Counter, *gdb.Counter:
var counter gdb.Counter
switch value := v.(type) {
case gdb.Counter:
counter = value
case *gdb.Counter:
counter = *value
}
operator, columnVal := "+", counter.Value
if columnVal < 0 {
operator, columnVal = "-", -columnVal
}
onDuplicateStr += fmt.Sprintf(
"%s=EXCLUDED.%s%s%s",
d.QuoteWord(k),
d.QuoteWord(counter.Field),
operator,
gconv.String(columnVal),
)
default: default:
onDuplicateStr += fmt.Sprintf( onDuplicateStr += fmt.Sprintf(
"%s=EXCLUDED.%s", "%s=EXCLUDED.%s",

View File

@ -4324,3 +4324,25 @@ func Test_OrderRandom(t *testing.T) {
t.Assert(len(result), TableSize) t.Assert(len(result), TableSize)
}) })
} }
func Test_Model_OnDuplicateWithCounter(t *testing.T) {
table := createInitTable()
defer dropTable(table)
gtest.C(t, func(t *gtest.T) {
data := g.Map{
"id": 1,
"passport": "pp1",
"password": "pw1",
"nickname": "n1",
"create_time": "2016-06-06",
}
_, err := db.Model(table).OnConflict("id").OnDuplicate(g.Map{
"id": gdb.Counter{Field: "id", Value: 999999},
}).Data(data).Save()
t.AssertNil(err)
one, err := db.Model(table).WherePri(1).One()
t.AssertNil(err)
t.AssertNil(one)
})
}

View File

@ -0,0 +1,90 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package sqlitecgo
import (
"fmt"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
)
// FormatUpsert returns SQL clause of type upsert for SQLite.
// For example: ON CONFLICT (id) DO UPDATE SET ...
func (d *Driver) FormatUpsert(columns []string, list gdb.List, option gdb.DoInsertOption) (string, error) {
if len(option.OnConflict) == 0 {
return "", gerror.NewCode(
gcode.CodeMissingParameter, `Please specify conflict columns`,
)
}
var onDuplicateStr string
if option.OnDuplicateStr != "" {
onDuplicateStr = option.OnDuplicateStr
} else if len(option.OnDuplicateMap) > 0 {
for k, v := range option.OnDuplicateMap {
if len(onDuplicateStr) > 0 {
onDuplicateStr += ","
}
switch v.(type) {
case gdb.Raw, *gdb.Raw:
onDuplicateStr += fmt.Sprintf(
"%s=%s",
d.Core.QuoteWord(k),
v,
)
case gdb.Counter, *gdb.Counter:
var counter gdb.Counter
switch value := v.(type) {
case gdb.Counter:
counter = value
case *gdb.Counter:
counter = *value
}
operator, columnVal := "+", counter.Value
if columnVal < 0 {
operator, columnVal = "-", -columnVal
}
onDuplicateStr += fmt.Sprintf(
"%s=EXCLUDED.%s%s%s",
d.QuoteWord(k),
d.QuoteWord(counter.Field),
operator,
gconv.String(columnVal),
)
default:
onDuplicateStr += fmt.Sprintf(
"%s=EXCLUDED.%s",
d.Core.QuoteWord(k),
d.Core.QuoteWord(gconv.String(v)),
)
}
}
} else {
for _, column := range columns {
// If it's SAVE operation, do not automatically update the creating time.
if d.Core.IsSoftCreatedFieldName(column) {
continue
}
if len(onDuplicateStr) > 0 {
onDuplicateStr += ","
}
onDuplicateStr += fmt.Sprintf(
"%s=EXCLUDED.%s",
d.Core.QuoteWord(column),
d.Core.QuoteWord(column),
)
}
}
conflictKeys := gstr.Join(option.OnConflict, ",")
return fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET ", conflictKeys) + onDuplicateStr, nil
}

View File

@ -10,8 +10,6 @@ import (
"context" "context"
"github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/text/gstr" "github.com/gogf/gf/v2/text/gstr"
) )
@ -26,14 +24,6 @@ func (d *Driver) DoFilter(
case gstr.HasPrefix(sql, gdb.InsertOperationReplace): case gstr.HasPrefix(sql, gdb.InsertOperationReplace):
sql = "INSERT OR REPLACE" + sql[len(gdb.InsertOperationReplace):] sql = "INSERT OR REPLACE" + sql[len(gdb.InsertOperationReplace):]
default:
if gstr.Contains(sql, gdb.InsertOnDuplicateKeyUpdate) {
return sql, args, gerror.NewCode(
gcode.CodeNotSupported,
`Save operation is not supported by sqlite driver`,
)
}
} }
return d.Core.DoFilter(ctx, link, sql, args) return d.Core.DoFilter(ctx, link, sql, args)
} }

View File

@ -12,6 +12,10 @@ import (
"github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/database/gdb"
) )
const (
tablesSqlTmp = `SELECT NAME FROM SQLITE_MASTER WHERE TYPE='table' ORDER BY NAME`
)
// Tables retrieves and returns the tables of current schema. // Tables retrieves and returns the tables of current schema.
// It's mainly used in cli tool chain for automatically generating the models. // It's mainly used in cli tool chain for automatically generating the models.
func (d *Driver) Tables(ctx context.Context, schema ...string) (tables []string, err error) { func (d *Driver) Tables(ctx context.Context, schema ...string) (tables []string, err error) {
@ -21,11 +25,7 @@ func (d *Driver) Tables(ctx context.Context, schema ...string) (tables []string,
return nil, err return nil, err
} }
result, err = d.DoSelect( result, err = d.DoSelect(ctx, link, tablesSqlTmp)
ctx,
link,
`SELECT NAME FROM SQLITE_MASTER WHERE TYPE='table' ORDER BY NAME`,
)
if err != nil { if err != nil {
return return
} }

View File

@ -11,8 +11,6 @@ import (
"github.com/gogf/gf/v2/container/garray" "github.com/gogf/gf/v2/container/garray"
"github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gctx" "github.com/gogf/gf/v2/os/gctx"
"github.com/gogf/gf/v2/os/gfile" "github.com/gogf/gf/v2/os/gfile"
@ -27,9 +25,6 @@ var (
configNode gdb.ConfigNode configNode gdb.ConfigNode
dbDir = gfile.Temp("sqlite") dbDir = gfile.Temp("sqlite")
ctx = gctx.New() ctx = gctx.New()
// Error
ErrorSave = gerror.NewCode(gcode.CodeNotSupported, `Save operation is not supported by sqlite driver`)
) )
const ( const (

View File

@ -375,7 +375,7 @@ func Test_Model_Save(t *testing.T) {
"nickname": "oldme", "nickname": "oldme",
"create_time": CreateTime, "create_time": CreateTime,
}).OnConflict("id").Save() }).OnConflict("id").Save()
t.Assert(err, ErrorSave) t.AssertNil(err)
}) })
} }
@ -4361,3 +4361,25 @@ func TestResult_Structs1(t *testing.T) {
t.Assert(array[1].Name, "smith") t.Assert(array[1].Name, "smith")
}) })
} }
func Test_Model_OnDuplicateWithCounter(t *testing.T) {
table := createInitTable()
defer dropTable(table)
gtest.C(t, func(t *gtest.T) {
data := g.Map{
"id": 1,
"passport": "pp1",
"password": "pw1",
"nickname": "n1",
"create_time": "2016-06-06",
}
_, err := db.Model(table).OnConflict("id").OnDuplicate(g.Map{
"id": gdb.Counter{Field: "id", Value: 999999},
}).Data(data).Save()
t.AssertNil(err)
one, err := db.Model(table).WherePri(1).One()
t.AssertNil(err)
t.AssertNil(one)
})
}

View File

@ -583,24 +583,8 @@ func (c *Core) DoUpdate(ctx context.Context, link Link, table string, data inter
switch kind { switch kind {
case reflect.Map, reflect.Struct: case reflect.Map, reflect.Struct:
var ( var (
fields []string fields []string
dataMap map[string]interface{} dataMap map[string]interface{}
counterHandler = func(column string, counter Counter) {
if counter.Value != 0 {
column = c.QuoteWord(column)
var (
columnRef = c.QuoteWord(counter.Field)
columnVal = counter.Value
operator = "+"
)
if columnVal < 0 {
operator = "-"
columnVal = -columnVal
}
fields = append(fields, fmt.Sprintf("%s=%s%s?", column, columnRef, operator))
params = append(params, columnVal)
}
}
) )
dataMap, err = c.ConvertDataForRecord(ctx, data, table) dataMap, err = c.ConvertDataForRecord(ctx, data, table)
if err != nil { if err != nil {
@ -620,13 +604,21 @@ func (c *Core) DoUpdate(ctx context.Context, link Link, table string, data inter
} }
for _, k := range keysInSequence { for _, k := range keysInSequence {
v := dataMap[k] v := dataMap[k]
switch value := v.(type) { switch v.(type) {
case *Counter: case Counter, *Counter:
counterHandler(k, *value) var counter Counter
switch value := v.(type) {
case Counter: case Counter:
counterHandler(k, value) counter = value
case *Counter:
counter = *value
}
if counter.Value == 0 {
continue
}
operator, columnVal := c.getCounterAlter(counter)
fields = append(fields, fmt.Sprintf("%s=%s%s?", c.QuoteWord(k), c.QuoteWord(counter.Field), operator))
params = append(params, columnVal)
default: default:
if s, ok := v.(Raw); ok { if s, ok := v.(Raw); ok {
fields = append(fields, c.QuoteWord(k)+"="+gconv.String(s)) fields = append(fields, c.QuoteWord(k)+"="+gconv.String(s))
@ -796,3 +788,12 @@ func (c *Core) IsSoftCreatedFieldName(fieldName string) bool {
func (c *Core) FormatSqlBeforeExecuting(sql string, args []interface{}) (newSql string, newArgs []interface{}) { func (c *Core) FormatSqlBeforeExecuting(sql string, args []interface{}) (newSql string, newArgs []interface{}) {
return handleSliceAndStructArgsForSql(sql, args) return handleSliceAndStructArgsForSql(sql, args)
} }
// getCounterAlter
func (c *Core) getCounterAlter(counter Counter) (operator string, columnVal float64) {
operator, columnVal = "+", counter.Value
if columnVal < 0 {
operator, columnVal = "-", -columnVal
}
return
}

View File

@ -388,6 +388,22 @@ func (c *Core) FormatUpsert(columns []string, list List, option DoInsertOption)
c.QuoteWord(k), c.QuoteWord(k),
v, v,
) )
case Counter, *Counter:
var counter Counter
switch value := v.(type) {
case Counter:
counter = value
case *Counter:
counter = *value
}
operator, columnVal := c.getCounterAlter(counter)
onDuplicateStr += fmt.Sprintf(
"%s=%s%s%s",
c.QuoteWord(k),
c.QuoteWord(counter.Field),
operator,
gconv.String(columnVal),
)
default: default:
onDuplicateStr += fmt.Sprintf( onDuplicateStr += fmt.Sprintf(
"%s=VALUES(%s)", "%s=VALUES(%s)",