diff --git a/pkg/common/db/relation/mysql_init.go b/pkg/common/db/relation/mysql_init.go index 0e5ea5e43..16b8c99fa 100644 --- a/pkg/common/db/relation/mysql_init.go +++ b/pkg/common/db/relation/mysql_init.go @@ -18,15 +18,12 @@ import ( "fmt" "time" - mysqldriver "github.com/go-sql-driver/mysql" - "gorm.io/driver/mysql" - "github.com/OpenIMSDK/tools/errs" "github.com/OpenIMSDK/tools/log" "github.com/OpenIMSDK/tools/mw/specialerror" - + mysqldriver "github.com/go-sql-driver/mysql" "github.com/openimsdk/open-im-server/v3/pkg/common/config" - + "gorm.io/driver/mysql" "gorm.io/gorm" "gorm.io/gorm/logger" ) @@ -35,56 +32,80 @@ const ( maxRetry = 100 // number of retries ) -// newMysqlGormDB Initialize the database connection. -func newMysqlGormDB() (*gorm.DB, error) { - dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", - config.Config.Mysql.Username, config.Config.Mysql.Password, config.Config.Mysql.Address[0], "mysql") +type option struct { + Username string + Password string + Address []string + Database string + LogLevel int + SlowThreshold int + MaxLifeTime int + MaxOpenConn int + MaxIdleConn int + Connect func(dsn string, maxRetry int) (*gorm.DB, error) +} - db, err := connectToDatabase(dsn, maxRetry) - if err != nil { - panic(err.Error() + " Open failed " + dsn) - } - sqlDB, err := db.DB() +// newMysqlGormDB Initialize the database connection. +func newMysqlGormDB(o *option) (*gorm.DB, error) { + err := maybeCreateTable(o) if err != nil { return nil, err } - defer sqlDB.Close() - sql := fmt.Sprintf( - "CREATE DATABASE IF NOT EXISTS %s default charset utf8mb4 COLLATE utf8mb4_unicode_ci;", - config.Config.Mysql.Database, - ) - err = db.Exec(sql).Error - if err != nil { - return nil, fmt.Errorf("init db %w", err) - } - dsn = fmt.Sprintf( - "%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", - config.Config.Mysql.Username, - config.Config.Mysql.Password, - config.Config.Mysql.Address[0], - config.Config.Mysql.Database, - ) + dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", + o.Username, o.Password, o.Address[0], o.Database) sqlLogger := log.NewSqlLogger( - logger.LogLevel(config.Config.Mysql.LogLevel), + logger.LogLevel(o.LogLevel), true, - time.Duration(config.Config.Mysql.SlowThreshold)*time.Millisecond, + time.Duration(o.SlowThreshold)*time.Millisecond, ) - db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ + db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{ Logger: sqlLogger, }) if err != nil { return nil, err } - sqlDB, err = db.DB() + sqlDB, err := db.DB() if err != nil { return nil, err } - sqlDB.SetConnMaxLifetime(time.Second * time.Duration(config.Config.Mysql.MaxLifeTime)) - sqlDB.SetMaxOpenConns(config.Config.Mysql.MaxOpenConn) - sqlDB.SetMaxIdleConns(config.Config.Mysql.MaxIdleConn) + sqlDB.SetConnMaxLifetime(time.Second * time.Duration(o.MaxLifeTime)) + sqlDB.SetMaxOpenConns(o.MaxOpenConn) + sqlDB.SetMaxIdleConns(o.MaxIdleConn) return db, nil } +// maybeCreateTable creates a database if it does not exists. +func maybeCreateTable(o *option) error { + dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", + o.Username, o.Password, o.Address[0], "mysql") + + var db *gorm.DB + var err error + if f := o.Connect; f != nil { + db, err = f(dsn, maxRetry) + } else { + db, err = connectToDatabase(dsn, maxRetry) + } + if err != nil { + panic(err.Error() + " Open failed " + dsn) + } + + sqlDB, err := db.DB() + if err != nil { + return err + } + defer sqlDB.Close() + sql := fmt.Sprintf( + "CREATE DATABASE IF NOT EXISTS `%s` default charset utf8mb4 COLLATE utf8mb4_unicode_ci", + o.Database, + ) + err = db.Exec(sql).Error + if err != nil { + return fmt.Errorf("init db %w", err) + } + return nil +} + // connectToDatabase Connection retry for mysql. func connectToDatabase(dsn string, maxRetry int) (*gorm.DB, error) { var db *gorm.DB @@ -106,7 +127,18 @@ func connectToDatabase(dsn string, maxRetry int) (*gorm.DB, error) { func NewGormDB() (*gorm.DB, error) { specialerror.AddReplace(gorm.ErrRecordNotFound, errs.ErrRecordNotFound) specialerror.AddErrHandler(replaceDuplicateKey) - return newMysqlGormDB() + + return newMysqlGormDB(&option{ + Username: config.Config.Mysql.Username, + Password: config.Config.Mysql.Password, + Address: config.Config.Mysql.Address, + Database: config.Config.Mysql.Database, + LogLevel: config.Config.Mysql.LogLevel, + SlowThreshold: config.Config.Mysql.SlowThreshold, + MaxLifeTime: config.Config.Mysql.MaxLifeTime, + MaxOpenConn: config.Config.Mysql.MaxOpenConn, + MaxIdleConn: config.Config.Mysql.MaxIdleConn, + }) } func replaceDuplicateKey(err error) errs.CodeError { diff --git a/pkg/common/db/relation/mysql_init_test.go b/pkg/common/db/relation/mysql_init_test.go new file mode 100644 index 000000000..c321dfd9f --- /dev/null +++ b/pkg/common/db/relation/mysql_init_test.go @@ -0,0 +1,121 @@ +package relation + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "reflect" + "testing" + + "gorm.io/driver/mysql" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +func TestMaybeCreateTable(t *testing.T) { + t.Run("normal", func(t *testing.T) { + err := maybeCreateTable(&option{ + Username: "root", + Password: "openIM123", + Address: []string{"172.28.0.1:13306"}, + Database: "openIM_v3", + LogLevel: 4, + SlowThreshold: 500, + MaxOpenConn: 1000, + MaxIdleConn: 100, + MaxLifeTime: 60, + Connect: connect(expectExec{ + query: "CREATE DATABASE IF NOT EXISTS `openIM_v3` default charset utf8mb4 COLLATE utf8mb4_unicode_ci", + args: nil, + }), + }) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("im-db", func(t *testing.T) { + err := maybeCreateTable(&option{ + Username: "root", + Password: "openIM123", + Address: []string{"172.28.0.1:13306"}, + Database: "im-db", + LogLevel: 4, + SlowThreshold: 500, + MaxOpenConn: 1000, + MaxIdleConn: 100, + MaxLifeTime: 60, + Connect: connect(expectExec{ + query: "CREATE DATABASE IF NOT EXISTS `im-db` default charset utf8mb4 COLLATE utf8mb4_unicode_ci", + args: nil, + }), + }) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("err", func(t *testing.T) { + e := errors.New("e") + err := maybeCreateTable(&option{ + Username: "root", + Password: "openIM123", + Address: []string{"172.28.0.1:13306"}, + Database: "openIM_v3", + LogLevel: 4, + SlowThreshold: 500, + MaxOpenConn: 1000, + MaxIdleConn: 100, + MaxLifeTime: 60, + Connect: connect(expectExec{ + err: e, + }), + }) + if !errors.Is(err, e) { + t.Fatalf("err not is e: %v", err) + } + }) +} + +func connect(e expectExec) func(string, int) (*gorm.DB, error) { + return func(string, int) (*gorm.DB, error) { + return gorm.Open(mysql.New(mysql.Config{ + SkipInitializeWithVersion: true, + Conn: sql.OpenDB(e), + }), &gorm.Config{ + Logger: logger.Discard, + }) + } +} + +type expectExec struct { + err error + query string + args []driver.NamedValue +} + +func (c expectExec) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + if c.err != nil { + return nil, c.err + } + if query != c.query { + return nil, fmt.Errorf("query mismatch. expect: %s, got: %s", c.query, query) + } + if reflect.DeepEqual(args, c.args) { + return nil, fmt.Errorf("args mismatch. expect: %v, got: %v", c.args, args) + } + return noEffectResult{}, nil +} + +func (e expectExec) Connect(context.Context) (driver.Conn, error) { return e, nil } +func (expectExec) Driver() driver.Driver { panic("not implemented") } +func (expectExec) Prepare(query string) (driver.Stmt, error) { panic("not implemented") } +func (expectExec) Close() (e error) { return } +func (expectExec) Begin() (driver.Tx, error) { panic("not implemented") } + +type noEffectResult struct{} + +func (noEffectResult) LastInsertId() (i int64, e error) { return } +func (noEffectResult) RowsAffected() (i int64, e error) { return }