mirror of
https://github.com/openimsdk/open-im-server.git
synced 2025-04-05 20:11:14 +08:00
fix: create database name (#1285)
This commit is contained in:
parent
38ab3e0ed7
commit
7722714251
@ -18,15 +18,12 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
mysqldriver "github.com/go-sql-driver/mysql"
|
||||
"gorm.io/driver/mysql"
|
||||
|
||||
"github.com/OpenIMSDK/tools/errs"
|
||||
"github.com/OpenIMSDK/tools/log"
|
||||
"github.com/OpenIMSDK/tools/mw/specialerror"
|
||||
|
||||
mysqldriver "github.com/go-sql-driver/mysql"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
@ -35,56 +32,80 @@ const (
|
||||
maxRetry = 100 // number of retries
|
||||
)
|
||||
|
||||
// newMysqlGormDB Initialize the database connection.
|
||||
func newMysqlGormDB() (*gorm.DB, error) {
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=true&loc=Local",
|
||||
config.Config.Mysql.Username, config.Config.Mysql.Password, config.Config.Mysql.Address[0], "mysql")
|
||||
type option struct {
|
||||
Username string
|
||||
Password string
|
||||
Address []string
|
||||
Database string
|
||||
LogLevel int
|
||||
SlowThreshold int
|
||||
MaxLifeTime int
|
||||
MaxOpenConn int
|
||||
MaxIdleConn int
|
||||
Connect func(dsn string, maxRetry int) (*gorm.DB, error)
|
||||
}
|
||||
|
||||
db, err := connectToDatabase(dsn, maxRetry)
|
||||
if err != nil {
|
||||
panic(err.Error() + " Open failed " + dsn)
|
||||
}
|
||||
sqlDB, err := db.DB()
|
||||
// newMysqlGormDB Initialize the database connection.
|
||||
func newMysqlGormDB(o *option) (*gorm.DB, error) {
|
||||
err := maybeCreateTable(o)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer sqlDB.Close()
|
||||
sql := fmt.Sprintf(
|
||||
"CREATE DATABASE IF NOT EXISTS %s default charset utf8mb4 COLLATE utf8mb4_unicode_ci;",
|
||||
config.Config.Mysql.Database,
|
||||
)
|
||||
err = db.Exec(sql).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init db %w", err)
|
||||
}
|
||||
dsn = fmt.Sprintf(
|
||||
"%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=true&loc=Local",
|
||||
config.Config.Mysql.Username,
|
||||
config.Config.Mysql.Password,
|
||||
config.Config.Mysql.Address[0],
|
||||
config.Config.Mysql.Database,
|
||||
)
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=true&loc=Local",
|
||||
o.Username, o.Password, o.Address[0], o.Database)
|
||||
sqlLogger := log.NewSqlLogger(
|
||||
logger.LogLevel(config.Config.Mysql.LogLevel),
|
||||
logger.LogLevel(o.LogLevel),
|
||||
true,
|
||||
time.Duration(config.Config.Mysql.SlowThreshold)*time.Millisecond,
|
||||
time.Duration(o.SlowThreshold)*time.Millisecond,
|
||||
)
|
||||
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||
Logger: sqlLogger,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlDB, err = db.DB()
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(config.Config.Mysql.MaxLifeTime))
|
||||
sqlDB.SetMaxOpenConns(config.Config.Mysql.MaxOpenConn)
|
||||
sqlDB.SetMaxIdleConns(config.Config.Mysql.MaxIdleConn)
|
||||
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(o.MaxLifeTime))
|
||||
sqlDB.SetMaxOpenConns(o.MaxOpenConn)
|
||||
sqlDB.SetMaxIdleConns(o.MaxIdleConn)
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// maybeCreateTable creates a database if it does not exists.
|
||||
func maybeCreateTable(o *option) error {
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=true&loc=Local",
|
||||
o.Username, o.Password, o.Address[0], "mysql")
|
||||
|
||||
var db *gorm.DB
|
||||
var err error
|
||||
if f := o.Connect; f != nil {
|
||||
db, err = f(dsn, maxRetry)
|
||||
} else {
|
||||
db, err = connectToDatabase(dsn, maxRetry)
|
||||
}
|
||||
if err != nil {
|
||||
panic(err.Error() + " Open failed " + dsn)
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sqlDB.Close()
|
||||
sql := fmt.Sprintf(
|
||||
"CREATE DATABASE IF NOT EXISTS `%s` default charset utf8mb4 COLLATE utf8mb4_unicode_ci",
|
||||
o.Database,
|
||||
)
|
||||
err = db.Exec(sql).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("init db %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// connectToDatabase Connection retry for mysql.
|
||||
func connectToDatabase(dsn string, maxRetry int) (*gorm.DB, error) {
|
||||
var db *gorm.DB
|
||||
@ -106,7 +127,18 @@ func connectToDatabase(dsn string, maxRetry int) (*gorm.DB, error) {
|
||||
func NewGormDB() (*gorm.DB, error) {
|
||||
specialerror.AddReplace(gorm.ErrRecordNotFound, errs.ErrRecordNotFound)
|
||||
specialerror.AddErrHandler(replaceDuplicateKey)
|
||||
return newMysqlGormDB()
|
||||
|
||||
return newMysqlGormDB(&option{
|
||||
Username: config.Config.Mysql.Username,
|
||||
Password: config.Config.Mysql.Password,
|
||||
Address: config.Config.Mysql.Address,
|
||||
Database: config.Config.Mysql.Database,
|
||||
LogLevel: config.Config.Mysql.LogLevel,
|
||||
SlowThreshold: config.Config.Mysql.SlowThreshold,
|
||||
MaxLifeTime: config.Config.Mysql.MaxLifeTime,
|
||||
MaxOpenConn: config.Config.Mysql.MaxOpenConn,
|
||||
MaxIdleConn: config.Config.Mysql.MaxIdleConn,
|
||||
})
|
||||
}
|
||||
|
||||
func replaceDuplicateKey(err error) errs.CodeError {
|
||||
|
121
pkg/common/db/relation/mysql_init_test.go
Normal file
121
pkg/common/db/relation/mysql_init_test.go
Normal file
@ -0,0 +1,121 @@
|
||||
package relation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
func TestMaybeCreateTable(t *testing.T) {
|
||||
t.Run("normal", func(t *testing.T) {
|
||||
err := maybeCreateTable(&option{
|
||||
Username: "root",
|
||||
Password: "openIM123",
|
||||
Address: []string{"172.28.0.1:13306"},
|
||||
Database: "openIM_v3",
|
||||
LogLevel: 4,
|
||||
SlowThreshold: 500,
|
||||
MaxOpenConn: 1000,
|
||||
MaxIdleConn: 100,
|
||||
MaxLifeTime: 60,
|
||||
Connect: connect(expectExec{
|
||||
query: "CREATE DATABASE IF NOT EXISTS `openIM_v3` default charset utf8mb4 COLLATE utf8mb4_unicode_ci",
|
||||
args: nil,
|
||||
}),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("im-db", func(t *testing.T) {
|
||||
err := maybeCreateTable(&option{
|
||||
Username: "root",
|
||||
Password: "openIM123",
|
||||
Address: []string{"172.28.0.1:13306"},
|
||||
Database: "im-db",
|
||||
LogLevel: 4,
|
||||
SlowThreshold: 500,
|
||||
MaxOpenConn: 1000,
|
||||
MaxIdleConn: 100,
|
||||
MaxLifeTime: 60,
|
||||
Connect: connect(expectExec{
|
||||
query: "CREATE DATABASE IF NOT EXISTS `im-db` default charset utf8mb4 COLLATE utf8mb4_unicode_ci",
|
||||
args: nil,
|
||||
}),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("err", func(t *testing.T) {
|
||||
e := errors.New("e")
|
||||
err := maybeCreateTable(&option{
|
||||
Username: "root",
|
||||
Password: "openIM123",
|
||||
Address: []string{"172.28.0.1:13306"},
|
||||
Database: "openIM_v3",
|
||||
LogLevel: 4,
|
||||
SlowThreshold: 500,
|
||||
MaxOpenConn: 1000,
|
||||
MaxIdleConn: 100,
|
||||
MaxLifeTime: 60,
|
||||
Connect: connect(expectExec{
|
||||
err: e,
|
||||
}),
|
||||
})
|
||||
if !errors.Is(err, e) {
|
||||
t.Fatalf("err not is e: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func connect(e expectExec) func(string, int) (*gorm.DB, error) {
|
||||
return func(string, int) (*gorm.DB, error) {
|
||||
return gorm.Open(mysql.New(mysql.Config{
|
||||
SkipInitializeWithVersion: true,
|
||||
Conn: sql.OpenDB(e),
|
||||
}), &gorm.Config{
|
||||
Logger: logger.Discard,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type expectExec struct {
|
||||
err error
|
||||
query string
|
||||
args []driver.NamedValue
|
||||
}
|
||||
|
||||
func (c expectExec) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
if c.err != nil {
|
||||
return nil, c.err
|
||||
}
|
||||
if query != c.query {
|
||||
return nil, fmt.Errorf("query mismatch. expect: %s, got: %s", c.query, query)
|
||||
}
|
||||
if reflect.DeepEqual(args, c.args) {
|
||||
return nil, fmt.Errorf("args mismatch. expect: %v, got: %v", c.args, args)
|
||||
}
|
||||
return noEffectResult{}, nil
|
||||
}
|
||||
|
||||
func (e expectExec) Connect(context.Context) (driver.Conn, error) { return e, nil }
|
||||
func (expectExec) Driver() driver.Driver { panic("not implemented") }
|
||||
func (expectExec) Prepare(query string) (driver.Stmt, error) { panic("not implemented") }
|
||||
func (expectExec) Close() (e error) { return }
|
||||
func (expectExec) Begin() (driver.Tx, error) { panic("not implemented") }
|
||||
|
||||
type noEffectResult struct{}
|
||||
|
||||
func (noEffectResult) LastInsertId() (i int64, e error) { return }
|
||||
func (noEffectResult) RowsAffected() (i int64, e error) { return }
|
Loading…
x
Reference in New Issue
Block a user