1
0
mirror of https://github.com/gogf/gf.git synced 2025-04-05 11:18:50 +08:00

enhance: Save operation support for contrib/drivers/dm (#3404)

This commit is contained in:
oldme 2024-03-20 19:18:25 +08:00 committed by GitHub
parent 164aad48c3
commit cade0775e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 137 additions and 193 deletions

View File

@ -10,9 +10,9 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"strings" "strings"
"github.com/gogf/gf/v2/container/gset"
"github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/errors/gerror"
@ -32,59 +32,71 @@ func (d *Driver) DoInsert(
) )
case gdb.InsertOptionSave: case gdb.InsertOptionSave:
// This syntax currently only supports design tables whose primary key is ID. return d.doSave(ctx, link, table, list, option)
listLength := len(list)
if listLength == 0 {
return nil, gerror.NewCode(
gcode.CodeInvalidRequest, `Save operation list is empty by dm driver`,
)
} }
var ( return d.Core.DoInsert(ctx, link, table, list, option)
keysSort []string
charL, charR = d.GetChars()
)
// Column names need to be aligned in the syntax
for k := range list[0] {
keysSort = append(keysSort, k)
}
var char = struct {
charL string
charR string
valueCharL string
valueCharR string
duplicateKey string
keys []string
}{
charL: charL,
charR: charR,
valueCharL: "'",
valueCharR: "'",
// TODO:: Need to dynamically set the primary key of the table
duplicateKey: "ID",
keys: keysSort,
} }
// insertKeys: Handle valid keys that need to be inserted and updated // doSave support upsert for dm
func (d *Driver) doSave(ctx context.Context,
link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption,
) (result sql.Result, err error) {
if len(option.OnConflict) == 0 {
return nil, gerror.NewCode(
gcode.CodeMissingParameter, `Please specify conflict columns`,
)
}
if len(list) == 0 {
return nil, gerror.NewCode(
gcode.CodeInvalidRequest, `Save operation list is empty by oracle driver`,
)
}
var (
one = list[0]
charL, charR = d.GetChars()
valueCharL, valueCharR = "'", "'"
conflictKeys = option.OnConflict
conflictKeySet = gset.New(false)
// insertKeys: Handle valid keys that need to be inserted
// insertValues: Handle values that need to be inserted // insertValues: Handle values that need to be inserted
// updateValues: Handle values that need to be updated // updateValues: Handle values that need to be updated
// queryValues: Handle only one insert with column name // queryValues: Handle data that need to be upsert
insertKeys, insertValues, updateValues, queryValues := parseValue(list[0], char) queryValues, insertKeys, insertValues, updateValues []string
// unionValues: Handling values that need to be inserted and updated )
unionValues := parseUnion(list[1:], char)
// conflictKeys slice type conv to set type
for _, conflictKey := range conflictKeys {
conflictKeySet.Add(gstr.ToUpper(conflictKey))
}
for key, value := range one {
saveValue := gconv.String(value)
queryValues = append(
queryValues,
fmt.Sprintf(
valueCharL+"%s"+valueCharR+" AS "+charL+"%s"+charR,
saveValue, key,
),
)
insertKeys = append(insertKeys, charL+key+charR)
insertValues = append(insertValues, "T2."+charL+key+charR)
// filter conflict keys in updateValues
if !conflictKeySet.Contains(key) {
updateValues = append(
updateValues,
fmt.Sprintf(`T1.%s = T2.%s`, charL+key+charR, charL+key+charR),
)
}
}
batchResult := new(gdb.SqlResult) batchResult := new(gdb.SqlResult)
// parseSql(): sqlStr := parseSqlForUpsert(table, queryValues, insertKeys, insertValues, updateValues, conflictKeys)
// MERGE INTO {{table}} T1
// USING ( SELECT {{queryValues}} FROM DUAL
// {{unionValues}} ) T2
// ON (T1.{{duplicateKey}} = T2.{{duplicateKey}})
// WHEN NOT MATCHED THEN
// INSERT {{insertKeys}} VALUES {{insertValues}}
// WHEN MATCHED THEN
// UPDATE SET {{updateValues}}
sqlStr := parseSql(
insertKeys, insertValues, updateValues, queryValues, unionValues, table, char.duplicateKey,
)
r, err := d.DoExec(ctx, link, sqlStr) r, err := d.DoExec(ctx, link, sqlStr)
if err != nil { if err != nil {
return r, err return r, err
@ -97,111 +109,41 @@ func (d *Driver) DoInsert(
} }
return batchResult, nil return batchResult, nil
} }
return d.Core.DoInsert(ctx, link, table, list, option)
}
func parseValue(listOne gdb.Map, char struct { // parseSqlForUpsert
charL string // MERGE INTO {{table}} T1
charR string // USING ( SELECT {{queryValues}} FROM DUAL T2
valueCharL string // ON (T1.{{duplicateKey}} = T2.{{duplicateKey}} AND ...)
valueCharR string // WHEN NOT MATCHED THEN
duplicateKey string // INSERT {{insertKeys}} VALUES {{insertValues}}
keys []string // WHEN MATCHED THEN
}) (insertKeys []string, insertValues []string, updateValues []string, queryValues []string) { // UPDATE SET {{updateValues}}
for _, column := range char.keys { func parseSqlForUpsert(table string,
if listOne[column] == nil { queryValues, insertKeys, insertValues, updateValues, duplicateKey []string,
// remove unassigned struct object
continue
}
insertKeys = append(insertKeys, char.charL+column+char.charR)
insertValues = append(insertValues, "T2."+char.charL+column+char.charR)
if column != char.duplicateKey {
updateValues = append(
updateValues,
fmt.Sprintf(`T1.%s = T2.%s`, char.charL+column+char.charR, char.charL+column+char.charR),
)
}
saveValue := gconv.String(listOne[column])
queryValues = append(
queryValues,
fmt.Sprintf(
char.valueCharL+"%s"+char.valueCharR+" AS "+char.charL+"%s"+char.charR,
saveValue, column,
),
)
}
return
}
func parseUnion(list gdb.List, char struct {
charL string
charR string
valueCharL string
valueCharR string
duplicateKey string
keys []string
}) (unionValues []string) {
for _, mapper := range list {
var saveValue []string
for _, column := range char.keys {
if mapper[column] == nil {
continue
}
// va := reflect.ValueOf(mapper[column])
// ty := reflect.TypeOf(mapper[column])
// switch ty.Kind() {
// case reflect.String:
// saveValue = append(saveValue, char.valueCharL+va.String()+char.valueCharR)
// case reflect.Int:
// saveValue = append(saveValue, strconv.FormatInt(va.Int(), 10))
// case reflect.Int64:
// saveValue = append(saveValue, strconv.FormatInt(va.Int(), 10))
// default:
// // The fish has no chance getting here.
// // Nothing to do.
// }
saveValue = append(saveValue,
fmt.Sprintf(
char.valueCharL+"%s"+char.valueCharR,
gconv.String(mapper[column]),
))
}
unionValues = append(
unionValues,
fmt.Sprintf(`UNION ALL SELECT %s FROM DUAL`, strings.Join(saveValue, ",")),
)
}
return
}
func parseSql(
insertKeys, insertValues, updateValues, queryValues, unionValues []string, table, duplicateKey string,
) (sqlStr string) { ) (sqlStr string) {
var ( var (
queryValueStr = strings.Join(queryValues, ",") queryValueStr = strings.Join(queryValues, ",")
unionValueStr = strings.Join(unionValues, " ")
insertKeyStr = strings.Join(insertKeys, ",") insertKeyStr = strings.Join(insertKeys, ",")
insertValueStr = strings.Join(insertValues, ",") insertValueStr = strings.Join(insertValues, ",")
updateValueStr = strings.Join(updateValues, ",") updateValueStr = strings.Join(updateValues, ",")
pattern = gstr.Trim(` duplicateKeyStr string
MERGE INTO %s T1 USING (SELECT %s FROM DUAL %s) T2 ON %s pattern = gstr.Trim(`MERGE INTO %s T1 USING (SELECT %s FROM DUAL) T2 ON (%s) WHEN NOT MATCHED THEN INSERT(%s) VALUES (%s) WHEN MATCHED THEN UPDATE SET %s;`)
WHEN NOT MATCHED
THEN
INSERT(%s) VALUES (%s)
WHEN MATCHED
THEN
UPDATE SET %s;
COMMIT;
`)
) )
return fmt.Sprintf(
pattern, for index, keys := range duplicateKey {
table, queryValueStr, unionValueStr, if index != 0 {
fmt.Sprintf("(T1.%s = T2.%s)", duplicateKey, duplicateKey), duplicateKeyStr += " AND "
insertKeyStr, insertValueStr, updateValueStr, }
duplicateTmp := fmt.Sprintf("T1.%s = T2.%s", keys, keys)
duplicateKeyStr += duplicateTmp
}
return fmt.Sprintf(pattern,
table,
queryValueStr,
duplicateKeyStr,
insertKeyStr,
insertValueStr,
updateValueStr,
) )
} }

View File

@ -7,6 +7,7 @@
package dm_test package dm_test
import ( import (
"database/sql"
"fmt" "fmt"
"strings" "strings"
"testing" "testing"
@ -138,52 +139,53 @@ func Test_DB_Query(t *testing.T) {
} }
func TestModelSave(t *testing.T) { func TestModelSave(t *testing.T) {
table := "A_tables" table := createTable("test")
createInitTable(table) defer dropTable(table)
gtest.C(t, func(t *gtest.T) { gtest.C(t, func(t *gtest.T) {
data := []User{ type User struct {
{ Id int
ID: 100, AccountName string
AccountName: "user_100", AttrIndex int
AttrIndex: 100,
CreatedTime: time.Now(),
},
} }
_, err := db.Model(table).Data(data).Save() var (
gtest.Assert(err, nil) user User
count int
result sql.Result
err error
)
db.SetDebug(true)
data2 := []User{ result, err = db.Model(table).Data(g.Map{
{ "id": 1,
ID: 101, "accountName": "ac1",
AccountName: "user_101", "attrIndex": 100,
}, }).OnConflict("id").Save()
}
_, err = db.Model(table).Data(&data2).Save()
gtest.Assert(err, nil)
data3 := []User{ t.AssertNil(err)
{ n, _ := result.RowsAffected()
ID: 10, t.Assert(n, 1)
AccountName: "user_10",
PwdReset: 10,
},
}
_, err = db.Model(table).Save(data3)
gtest.Assert(err, nil)
data4 := []User{ err = db.Model(table).Scan(&user)
{ t.AssertNil(err)
ID: 9, t.Assert(user.Id, 1)
AccountName: "user_9", t.Assert(user.AccountName, "ac1")
CreatedTime: time.Now(), t.Assert(user.AttrIndex, 100)
},
}
_, err = db.Model(table).Save(&data4)
gtest.Assert(err, nil)
// TODO:: Should be Supported 'Replace' Operation _, err = db.Model(table).Data(g.Map{
// _, err = db.Schema(TestDBName).Replace(ctx, "DoInsert", data, 10) "id": 1,
// gtest.Assert(err, nil) "accountName": "ac2",
"attrIndex": 200,
}).OnConflict("id").Save()
t.AssertNil(err)
err = db.Model(table).Scan(&user)
t.AssertNil(err)
t.Assert(user.AccountName, "ac2")
t.Assert(user.AttrIndex, 200)
count, err = db.Model(table).Count()
t.AssertNil(err)
t.Assert(count, 1)
}) })
} }