1
0
mirror of https://github.com/gogf/gf.git synced 2025-04-05 11:18:50 +08:00

enhance: improve FormatUpsert implements for pgsql (#3349)

This commit is contained in:
oldme 2024-03-06 19:05:13 +08:00 committed by GitHub
parent df15d70466
commit 97fcd9d726
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 118 additions and 62 deletions

View File

@ -7,7 +7,7 @@
// Package pgsql implements gdb.Driver, which supports operations for database PostgreSQL.
//
// Note:
// 1. It does not support Save/Replace features.
// 1. It does not support Replace features.
// 2. It does not support Insert Ignore features.
package pgsql

View File

@ -9,13 +9,10 @@ package pgsql
import (
"context"
"database/sql"
"fmt"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
)
// DoInsert inserts or updates data forF given table.
@ -47,55 +44,3 @@ func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list
}
return d.Core.DoInsert(ctx, link, table, list, option)
}
// FormatUpsert returns SQL clause of type upsert for PgSQL.
// For example: ON CONFLICT (id) DO UPDATE SET ...
func (d *Driver) FormatUpsert(columns []string, list gdb.List, option gdb.DoInsertOption) (string, error) {
if len(option.OnConflict) == 0 {
return "", gerror.New("Please specify conflict columns")
}
var onDuplicateStr string
if option.OnDuplicateStr != "" {
onDuplicateStr = option.OnDuplicateStr
} else if len(option.OnDuplicateMap) > 0 {
for k, v := range option.OnDuplicateMap {
if len(onDuplicateStr) > 0 {
onDuplicateStr += ","
}
switch v.(type) {
case gdb.Raw, *gdb.Raw:
onDuplicateStr += fmt.Sprintf(
"%s=%s",
d.Core.QuoteWord(k),
v,
)
default:
onDuplicateStr += fmt.Sprintf(
"%s=EXCLUDED.%s",
d.Core.QuoteWord(k),
d.Core.QuoteWord(gconv.String(v)),
)
}
}
} else {
for _, column := range columns {
// If it's SAVE operation, do not automatically update the creating time.
if d.Core.IsSoftCreatedFieldName(column) {
continue
}
if len(onDuplicateStr) > 0 {
onDuplicateStr += ","
}
onDuplicateStr += fmt.Sprintf(
"%s=EXCLUDED.%s",
d.Core.QuoteWord(column),
d.Core.QuoteWord(column),
)
}
}
conflictKeys := gstr.Join(option.OnConflict, ",")
return fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET ", conflictKeys) + onDuplicateStr, nil
}

View File

@ -0,0 +1,68 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package pgsql
import (
"fmt"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
)
// FormatUpsert returns SQL clause of type upsert for PgSQL.
// For example: ON CONFLICT (id) DO UPDATE SET ...
func (d *Driver) FormatUpsert(columns []string, list gdb.List, option gdb.DoInsertOption) (string, error) {
if len(option.OnConflict) == 0 {
return "", gerror.New("Please specify conflict columns")
}
var onDuplicateStr string
if option.OnDuplicateStr != "" {
onDuplicateStr = option.OnDuplicateStr
} else if len(option.OnDuplicateMap) > 0 {
for k, v := range option.OnDuplicateMap {
if len(onDuplicateStr) > 0 {
onDuplicateStr += ","
}
switch v.(type) {
case gdb.Raw, *gdb.Raw:
onDuplicateStr += fmt.Sprintf(
"%s=%s",
d.Core.QuoteWord(k),
v,
)
default:
onDuplicateStr += fmt.Sprintf(
"%s=EXCLUDED.%s",
d.Core.QuoteWord(k),
d.Core.QuoteWord(gconv.String(v)),
)
}
}
} else {
for _, column := range columns {
// If it's SAVE operation, do not automatically update the creating time.
if d.Core.IsSoftCreatedFieldName(column) {
continue
}
if len(onDuplicateStr) > 0 {
onDuplicateStr += ","
}
onDuplicateStr += fmt.Sprintf(
"%s=EXCLUDED.%s",
d.Core.QuoteWord(column),
d.Core.QuoteWord(column),
)
}
}
conflictKeys := gstr.Join(option.OnConflict, ",")
return fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET ", conflictKeys) + onDuplicateStr, nil
}

View File

@ -7,6 +7,7 @@
package pgsql_test
import (
"database/sql"
"fmt"
"testing"
@ -260,16 +261,58 @@ func Test_Model_Save(t *testing.T) {
table := createTable()
defer dropTable(table)
gtest.C(t, func(t *gtest.T) {
result, err := db.Model(table).Data(g.Map{
type User struct {
Id int
Passport string
Password string
NickName string
CreateTime *gtime.Time
}
var (
user User
count int
result sql.Result
err error
)
result, err = db.Model(table).Data(g.Map{
"id": 1,
"passport": "t111",
"password": "25d55ad283aa400af464c76d713c07ad",
"nickname": "T111",
"create_time": "2018-10-24 10:00:00",
"passport": "p1",
"password": "pw1",
"nickname": "n1",
"create_time": CreateTime,
}).OnConflict("id").Save()
t.AssertNil(err)
t.AssertNil(nil)
n, _ := result.RowsAffected()
t.Assert(n, 1)
err = db.Model(table).Scan(&user)
t.Assert(err, nil)
t.Assert(user.Id, 1)
t.Assert(user.Passport, "p1")
t.Assert(user.Password, "pw1")
t.Assert(user.NickName, "n1")
t.Assert(user.CreateTime.String(), CreateTime)
_, err = db.Model(table).Data(g.Map{
"id": 1,
"passport": "p1",
"password": "pw2",
"nickname": "n2",
"create_time": CreateTime,
}).OnConflict("id").Save()
t.AssertNil(err)
err = db.Model(table).Scan(&user)
t.Assert(err, nil)
t.Assert(user.Passport, "p1")
t.Assert(user.Password, "pw2")
t.Assert(user.NickName, "n2")
t.Assert(user.CreateTime.String(), CreateTime)
count, err = db.Model(table).Count()
t.Assert(err, nil)
t.Assert(count, 1)
})
}