mirror of
https://github.com/gogf/gf.git
synced 2025-04-05 11:18:50 +08:00
enhance: Save
operation support for contrib/drivers/dm
(#3404)
This commit is contained in:
parent
164aad48c3
commit
cade0775e8
@ -10,9 +10,9 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gogf/gf/v2/container/gset"
|
||||||
"github.com/gogf/gf/v2/database/gdb"
|
"github.com/gogf/gf/v2/database/gdb"
|
||||||
"github.com/gogf/gf/v2/errors/gcode"
|
"github.com/gogf/gf/v2/errors/gcode"
|
||||||
"github.com/gogf/gf/v2/errors/gerror"
|
"github.com/gogf/gf/v2/errors/gerror"
|
||||||
@ -32,59 +32,71 @@ func (d *Driver) DoInsert(
|
|||||||
)
|
)
|
||||||
|
|
||||||
case gdb.InsertOptionSave:
|
case gdb.InsertOptionSave:
|
||||||
// This syntax currently only supports design tables whose primary key is ID.
|
return d.doSave(ctx, link, table, list, option)
|
||||||
listLength := len(list)
|
|
||||||
if listLength == 0 {
|
|
||||||
return nil, gerror.NewCode(
|
|
||||||
gcode.CodeInvalidRequest, `Save operation list is empty by dm driver`,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
var (
|
return d.Core.DoInsert(ctx, link, table, list, option)
|
||||||
keysSort []string
|
|
||||||
charL, charR = d.GetChars()
|
|
||||||
)
|
|
||||||
// Column names need to be aligned in the syntax
|
|
||||||
for k := range list[0] {
|
|
||||||
keysSort = append(keysSort, k)
|
|
||||||
}
|
|
||||||
var char = struct {
|
|
||||||
charL string
|
|
||||||
charR string
|
|
||||||
valueCharL string
|
|
||||||
valueCharR string
|
|
||||||
duplicateKey string
|
|
||||||
keys []string
|
|
||||||
}{
|
|
||||||
charL: charL,
|
|
||||||
charR: charR,
|
|
||||||
valueCharL: "'",
|
|
||||||
valueCharR: "'",
|
|
||||||
// TODO:: Need to dynamically set the primary key of the table
|
|
||||||
duplicateKey: "ID",
|
|
||||||
keys: keysSort,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// insertKeys: Handle valid keys that need to be inserted and updated
|
// doSave support upsert for dm
|
||||||
|
func (d *Driver) doSave(ctx context.Context,
|
||||||
|
link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption,
|
||||||
|
) (result sql.Result, err error) {
|
||||||
|
if len(option.OnConflict) == 0 {
|
||||||
|
return nil, gerror.NewCode(
|
||||||
|
gcode.CodeMissingParameter, `Please specify conflict columns`,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(list) == 0 {
|
||||||
|
return nil, gerror.NewCode(
|
||||||
|
gcode.CodeInvalidRequest, `Save operation list is empty by oracle driver`,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
one = list[0]
|
||||||
|
charL, charR = d.GetChars()
|
||||||
|
valueCharL, valueCharR = "'", "'"
|
||||||
|
|
||||||
|
conflictKeys = option.OnConflict
|
||||||
|
conflictKeySet = gset.New(false)
|
||||||
|
|
||||||
|
// insertKeys: Handle valid keys that need to be inserted
|
||||||
// insertValues: Handle values that need to be inserted
|
// insertValues: Handle values that need to be inserted
|
||||||
// updateValues: Handle values that need to be updated
|
// updateValues: Handle values that need to be updated
|
||||||
// queryValues: Handle only one insert with column name
|
// queryValues: Handle data that need to be upsert
|
||||||
insertKeys, insertValues, updateValues, queryValues := parseValue(list[0], char)
|
queryValues, insertKeys, insertValues, updateValues []string
|
||||||
// unionValues: Handling values that need to be inserted and updated
|
)
|
||||||
unionValues := parseUnion(list[1:], char)
|
|
||||||
|
// conflictKeys slice type conv to set type
|
||||||
|
for _, conflictKey := range conflictKeys {
|
||||||
|
conflictKeySet.Add(gstr.ToUpper(conflictKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range one {
|
||||||
|
saveValue := gconv.String(value)
|
||||||
|
queryValues = append(
|
||||||
|
queryValues,
|
||||||
|
fmt.Sprintf(
|
||||||
|
valueCharL+"%s"+valueCharR+" AS "+charL+"%s"+charR,
|
||||||
|
saveValue, key,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
insertKeys = append(insertKeys, charL+key+charR)
|
||||||
|
insertValues = append(insertValues, "T2."+charL+key+charR)
|
||||||
|
|
||||||
|
// filter conflict keys in updateValues
|
||||||
|
if !conflictKeySet.Contains(key) {
|
||||||
|
updateValues = append(
|
||||||
|
updateValues,
|
||||||
|
fmt.Sprintf(`T1.%s = T2.%s`, charL+key+charR, charL+key+charR),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
batchResult := new(gdb.SqlResult)
|
batchResult := new(gdb.SqlResult)
|
||||||
// parseSql():
|
sqlStr := parseSqlForUpsert(table, queryValues, insertKeys, insertValues, updateValues, conflictKeys)
|
||||||
// MERGE INTO {{table}} T1
|
|
||||||
// USING ( SELECT {{queryValues}} FROM DUAL
|
|
||||||
// {{unionValues}} ) T2
|
|
||||||
// ON (T1.{{duplicateKey}} = T2.{{duplicateKey}})
|
|
||||||
// WHEN NOT MATCHED THEN
|
|
||||||
// INSERT {{insertKeys}} VALUES {{insertValues}}
|
|
||||||
// WHEN MATCHED THEN
|
|
||||||
// UPDATE SET {{updateValues}}
|
|
||||||
sqlStr := parseSql(
|
|
||||||
insertKeys, insertValues, updateValues, queryValues, unionValues, table, char.duplicateKey,
|
|
||||||
)
|
|
||||||
r, err := d.DoExec(ctx, link, sqlStr)
|
r, err := d.DoExec(ctx, link, sqlStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return r, err
|
return r, err
|
||||||
@ -97,111 +109,41 @@ func (d *Driver) DoInsert(
|
|||||||
}
|
}
|
||||||
return batchResult, nil
|
return batchResult, nil
|
||||||
}
|
}
|
||||||
return d.Core.DoInsert(ctx, link, table, list, option)
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseValue(listOne gdb.Map, char struct {
|
// parseSqlForUpsert
|
||||||
charL string
|
// MERGE INTO {{table}} T1
|
||||||
charR string
|
// USING ( SELECT {{queryValues}} FROM DUAL T2
|
||||||
valueCharL string
|
// ON (T1.{{duplicateKey}} = T2.{{duplicateKey}} AND ...)
|
||||||
valueCharR string
|
// WHEN NOT MATCHED THEN
|
||||||
duplicateKey string
|
// INSERT {{insertKeys}} VALUES {{insertValues}}
|
||||||
keys []string
|
// WHEN MATCHED THEN
|
||||||
}) (insertKeys []string, insertValues []string, updateValues []string, queryValues []string) {
|
// UPDATE SET {{updateValues}}
|
||||||
for _, column := range char.keys {
|
func parseSqlForUpsert(table string,
|
||||||
if listOne[column] == nil {
|
queryValues, insertKeys, insertValues, updateValues, duplicateKey []string,
|
||||||
// remove unassigned struct object
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
insertKeys = append(insertKeys, char.charL+column+char.charR)
|
|
||||||
insertValues = append(insertValues, "T2."+char.charL+column+char.charR)
|
|
||||||
if column != char.duplicateKey {
|
|
||||||
updateValues = append(
|
|
||||||
updateValues,
|
|
||||||
fmt.Sprintf(`T1.%s = T2.%s`, char.charL+column+char.charR, char.charL+column+char.charR),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
saveValue := gconv.String(listOne[column])
|
|
||||||
queryValues = append(
|
|
||||||
queryValues,
|
|
||||||
fmt.Sprintf(
|
|
||||||
char.valueCharL+"%s"+char.valueCharR+" AS "+char.charL+"%s"+char.charR,
|
|
||||||
saveValue, column,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseUnion(list gdb.List, char struct {
|
|
||||||
charL string
|
|
||||||
charR string
|
|
||||||
valueCharL string
|
|
||||||
valueCharR string
|
|
||||||
duplicateKey string
|
|
||||||
keys []string
|
|
||||||
}) (unionValues []string) {
|
|
||||||
for _, mapper := range list {
|
|
||||||
var saveValue []string
|
|
||||||
for _, column := range char.keys {
|
|
||||||
if mapper[column] == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// va := reflect.ValueOf(mapper[column])
|
|
||||||
// ty := reflect.TypeOf(mapper[column])
|
|
||||||
// switch ty.Kind() {
|
|
||||||
// case reflect.String:
|
|
||||||
// saveValue = append(saveValue, char.valueCharL+va.String()+char.valueCharR)
|
|
||||||
|
|
||||||
// case reflect.Int:
|
|
||||||
// saveValue = append(saveValue, strconv.FormatInt(va.Int(), 10))
|
|
||||||
|
|
||||||
// case reflect.Int64:
|
|
||||||
// saveValue = append(saveValue, strconv.FormatInt(va.Int(), 10))
|
|
||||||
|
|
||||||
// default:
|
|
||||||
// // The fish has no chance getting here.
|
|
||||||
// // Nothing to do.
|
|
||||||
// }
|
|
||||||
saveValue = append(saveValue,
|
|
||||||
fmt.Sprintf(
|
|
||||||
char.valueCharL+"%s"+char.valueCharR,
|
|
||||||
gconv.String(mapper[column]),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
unionValues = append(
|
|
||||||
unionValues,
|
|
||||||
fmt.Sprintf(`UNION ALL SELECT %s FROM DUAL`, strings.Join(saveValue, ",")),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseSql(
|
|
||||||
insertKeys, insertValues, updateValues, queryValues, unionValues []string, table, duplicateKey string,
|
|
||||||
) (sqlStr string) {
|
) (sqlStr string) {
|
||||||
var (
|
var (
|
||||||
queryValueStr = strings.Join(queryValues, ",")
|
queryValueStr = strings.Join(queryValues, ",")
|
||||||
unionValueStr = strings.Join(unionValues, " ")
|
|
||||||
insertKeyStr = strings.Join(insertKeys, ",")
|
insertKeyStr = strings.Join(insertKeys, ",")
|
||||||
insertValueStr = strings.Join(insertValues, ",")
|
insertValueStr = strings.Join(insertValues, ",")
|
||||||
updateValueStr = strings.Join(updateValues, ",")
|
updateValueStr = strings.Join(updateValues, ",")
|
||||||
pattern = gstr.Trim(`
|
duplicateKeyStr string
|
||||||
MERGE INTO %s T1 USING (SELECT %s FROM DUAL %s) T2 ON %s
|
pattern = gstr.Trim(`MERGE INTO %s T1 USING (SELECT %s FROM DUAL) T2 ON (%s) WHEN NOT MATCHED THEN INSERT(%s) VALUES (%s) WHEN MATCHED THEN UPDATE SET %s;`)
|
||||||
WHEN NOT MATCHED
|
|
||||||
THEN
|
|
||||||
INSERT(%s) VALUES (%s)
|
|
||||||
WHEN MATCHED
|
|
||||||
THEN
|
|
||||||
UPDATE SET %s;
|
|
||||||
COMMIT;
|
|
||||||
`)
|
|
||||||
)
|
)
|
||||||
return fmt.Sprintf(
|
|
||||||
pattern,
|
for index, keys := range duplicateKey {
|
||||||
table, queryValueStr, unionValueStr,
|
if index != 0 {
|
||||||
fmt.Sprintf("(T1.%s = T2.%s)", duplicateKey, duplicateKey),
|
duplicateKeyStr += " AND "
|
||||||
insertKeyStr, insertValueStr, updateValueStr,
|
}
|
||||||
|
duplicateTmp := fmt.Sprintf("T1.%s = T2.%s", keys, keys)
|
||||||
|
duplicateKeyStr += duplicateTmp
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf(pattern,
|
||||||
|
table,
|
||||||
|
queryValueStr,
|
||||||
|
duplicateKeyStr,
|
||||||
|
insertKeyStr,
|
||||||
|
insertValueStr,
|
||||||
|
updateValueStr,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@
|
|||||||
package dm_test
|
package dm_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@ -138,52 +139,53 @@ func Test_DB_Query(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestModelSave(t *testing.T) {
|
func TestModelSave(t *testing.T) {
|
||||||
table := "A_tables"
|
table := createTable("test")
|
||||||
createInitTable(table)
|
defer dropTable(table)
|
||||||
gtest.C(t, func(t *gtest.T) {
|
gtest.C(t, func(t *gtest.T) {
|
||||||
data := []User{
|
type User struct {
|
||||||
{
|
Id int
|
||||||
ID: 100,
|
AccountName string
|
||||||
AccountName: "user_100",
|
AttrIndex int
|
||||||
AttrIndex: 100,
|
|
||||||
CreatedTime: time.Now(),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
_, err := db.Model(table).Data(data).Save()
|
var (
|
||||||
gtest.Assert(err, nil)
|
user User
|
||||||
|
count int
|
||||||
|
result sql.Result
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
db.SetDebug(true)
|
||||||
|
|
||||||
data2 := []User{
|
result, err = db.Model(table).Data(g.Map{
|
||||||
{
|
"id": 1,
|
||||||
ID: 101,
|
"accountName": "ac1",
|
||||||
AccountName: "user_101",
|
"attrIndex": 100,
|
||||||
},
|
}).OnConflict("id").Save()
|
||||||
}
|
|
||||||
_, err = db.Model(table).Data(&data2).Save()
|
|
||||||
gtest.Assert(err, nil)
|
|
||||||
|
|
||||||
data3 := []User{
|
t.AssertNil(err)
|
||||||
{
|
n, _ := result.RowsAffected()
|
||||||
ID: 10,
|
t.Assert(n, 1)
|
||||||
AccountName: "user_10",
|
|
||||||
PwdReset: 10,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
_, err = db.Model(table).Save(data3)
|
|
||||||
gtest.Assert(err, nil)
|
|
||||||
|
|
||||||
data4 := []User{
|
err = db.Model(table).Scan(&user)
|
||||||
{
|
t.AssertNil(err)
|
||||||
ID: 9,
|
t.Assert(user.Id, 1)
|
||||||
AccountName: "user_9",
|
t.Assert(user.AccountName, "ac1")
|
||||||
CreatedTime: time.Now(),
|
t.Assert(user.AttrIndex, 100)
|
||||||
},
|
|
||||||
}
|
|
||||||
_, err = db.Model(table).Save(&data4)
|
|
||||||
gtest.Assert(err, nil)
|
|
||||||
|
|
||||||
// TODO:: Should be Supported 'Replace' Operation
|
_, err = db.Model(table).Data(g.Map{
|
||||||
// _, err = db.Schema(TestDBName).Replace(ctx, "DoInsert", data, 10)
|
"id": 1,
|
||||||
// gtest.Assert(err, nil)
|
"accountName": "ac2",
|
||||||
|
"attrIndex": 200,
|
||||||
|
}).OnConflict("id").Save()
|
||||||
|
t.AssertNil(err)
|
||||||
|
|
||||||
|
err = db.Model(table).Scan(&user)
|
||||||
|
t.AssertNil(err)
|
||||||
|
t.Assert(user.AccountName, "ac2")
|
||||||
|
t.Assert(user.AttrIndex, 200)
|
||||||
|
|
||||||
|
count, err = db.Model(table).Count()
|
||||||
|
t.AssertNil(err)
|
||||||
|
t.Assert(count, 1)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user