1
0
mirror of https://github.com/gogf/gf.git synced 2025-04-05 11:18:50 +08:00

improve function New for creating db by specified configuration node

This commit is contained in:
John Guo 2022-01-12 21:57:46 +08:00
parent 45fbb5326c
commit e5613e8690
12 changed files with 106 additions and 62 deletions

View File

@ -306,17 +306,25 @@ const (
SqlTypeStmtQueryRowContext = "DB.Statement.QueryRowContext"
)
const (
DriverNameMysql = `mysql`
DriverNameMssql = `mssql`
DriverNamePgsql = `pgsql`
DriverNameOracle = `oracle`
DriverNameSqlite = `sqlite`
)
var (
// instances is the management map for instances.
instances = gmap.NewStrAnyMap(true)
// driverMap manages all custom registered driver.
driverMap = map[string]Driver{
"mysql": &DriverMysql{},
"mssql": &DriverMssql{},
"pgsql": &DriverPgsql{},
"oracle": &DriverOracle{},
"sqlite": &DriverSqlite{},
DriverNameMysql: &DriverMysql{},
DriverNameMssql: &DriverMssql{},
DriverNamePgsql: &DriverPgsql{},
DriverNameOracle: &DriverOracle{},
DriverNameSqlite: &DriverSqlite{},
}
// lastOperatorRegPattern is the regular expression pattern for a string
@ -351,10 +359,15 @@ func Register(name string, driver Driver) error {
return nil
}
// New creates and returns an ORM object with global configurations.
// New creates and returns an ORM object with given configuration node.
func New(node ConfigNode) (db DB, err error) {
return doNewByNode(node, "")
}
// NewByGroup creates and returns an ORM object with global configurations.
// The parameter `name` specifies the configuration group name,
// which is DefaultGroupName in default.
func New(group ...string) (db DB, err error) {
func NewByGroup(group ...string) (db DB, err error) {
groupName := configs.group
if len(group) > 0 && group[0] != "" {
groupName = group[0]
@ -369,29 +382,9 @@ func New(group ...string) (db DB, err error) {
)
}
if _, ok := configs.config[groupName]; ok {
if node, err := getConfigNodeByGroup(groupName, true); err == nil {
c := &Core{
group: groupName,
debug: gtype.NewBool(),
cache: gcache.New(),
links: gmap.NewStrAnyMap(true),
schema: gtype.NewString(),
logger: glog.New(),
config: node,
}
if v, ok := driverMap[node.Type]; ok {
c.db, err = v.New(c, node)
if err != nil {
return nil, err
}
return c.db, nil
} else {
return nil, gerror.NewCodef(
gcode.CodeInvalidConfiguration,
`cannot find database driver for specified database type "%s", did you misspell type name "%s" or forget importing the database driver?`,
node.Type, node.Type,
)
}
var node *ConfigNode
if node, err = getConfigNodeByGroup(groupName, true); err == nil {
return doNewByNode(*node, groupName)
} else {
return nil, err
}
@ -404,6 +397,31 @@ func New(group ...string) (db DB, err error) {
}
}
// doNewByNode creates and returns an ORM object with given configuration node and group name.
func doNewByNode(node ConfigNode, group string) (db DB, err error) {
c := &Core{
group: group,
debug: gtype.NewBool(),
cache: gcache.New(),
links: gmap.NewStrAnyMap(true),
schema: gtype.NewString(),
logger: glog.New(),
config: &node,
}
if v, ok := driverMap[node.Type]; ok {
c.db, err = v.New(c, &node)
if err != nil {
return nil, err
}
return c.db, nil
}
return nil, gerror.NewCodef(
gcode.CodeInvalidConfiguration,
`cannot find database driver for specified database type "%s", did you misspell type name "%s" or forget importing the database driver?`,
node.Type, node.Type,
)
}
// Instance returns an instance for DB operations.
// The parameter `name` specifies the configuration group name,
// which is DefaultGroupName in default.
@ -413,7 +431,7 @@ func Instance(name ...string) (db DB, err error) {
group = name[0]
}
v := instances.GetOrSetFuncLock(group, func() interface{} {
db, err = New(group)
db, err = NewByGroup(group)
return db
})
if v != nil {
@ -502,14 +520,19 @@ func getConfigNodeByWeight(cg ConfigGroup) *ConfigNode {
return nil
}
// getSqlDb retrieves and returns a underlying database connection object.
// getSqlDb retrieves and returns an underlying database connection object.
// The parameter `master` specifies whether retrieves master node connection if
// master-slave nodes are configured.
func (c *Core) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err error) {
// Load balance.
node, err := getConfigNodeByGroup(c.group, master)
if err != nil {
return nil, err
var node *ConfigNode
if c.group != "" {
node, err = getConfigNodeByGroup(c.group, master)
if err != nil {
return nil, err
}
} else {
node = c.config
}
// Default value checks.
if node.Charset == "" {

View File

@ -97,6 +97,9 @@ func (c *Core) GetCtxTimeout(timeoutType int, ctx context.Context) (context.Cont
// It is rare to Close a DB, as the DB handle is meant to be
// long-lived and shared between many goroutines.
func (c *Core) Close(ctx context.Context) (err error) {
if err = c.cache.Close(ctx); err != nil {
return err
}
c.links.LockFunc(func(m map[string]interface{}) {
for k, v := range m {
if db, ok := v.(*sql.DB); ok {

View File

@ -41,8 +41,8 @@ func (d *DriverMssql) New(core *Core, node *ConfigNode) (DB, error) {
// Open creates and returns an underlying sql.DB object for mssql.
func (d *DriverMssql) Open(config *ConfigNode) (db *sql.DB, err error) {
var (
source string
driver = "sqlserver"
source string
underlyingDriverName = "sqlserver"
)
if config.Link != "" {
source = config.Link
@ -53,10 +53,10 @@ func (d *DriverMssql) Open(config *ConfigNode) (db *sql.DB, err error) {
)
}
intlog.Printf(d.GetCtx(), "Open: %s", source)
if db, err = sql.Open(driver, source); err != nil {
if db, err = sql.Open(underlyingDriverName, source); err != nil {
err = gerror.WrapCodef(
gcode.CodeDbOperationError, err,
`sql.Open failed for driver "%s" by source "%s"`, driver, source,
`sql.Open failed for driver "%s" by source "%s"`, underlyingDriverName, source,
)
return nil, err
}

View File

@ -38,8 +38,8 @@ func (d *DriverMysql) New(core *Core, node *ConfigNode) (DB, error) {
// Note that it converts time.Time argument to local timezone in default.
func (d *DriverMysql) Open(config *ConfigNode) (db *sql.DB, err error) {
var (
source string
driver = "mysql"
source string
underlyingDriverName = "mysql"
)
if config.Link != "" {
source = config.Link
@ -57,10 +57,10 @@ func (d *DriverMysql) Open(config *ConfigNode) (db *sql.DB, err error) {
}
}
intlog.Printf(d.GetCtx(), "Open: %s", source)
if db, err = sql.Open(driver, source); err != nil {
if db, err = sql.Open(underlyingDriverName, source); err != nil {
err = gerror.WrapCodef(
gcode.CodeDbOperationError, err,
`sql.Open failed for driver "%s" by source "%s"`, driver, source,
`sql.Open failed for driver "%s" by source "%s"`, underlyingDriverName, source,
)
return nil, err
}

View File

@ -41,11 +41,11 @@ func (d *DriverOracle) New(core *Core, node *ConfigNode) (DB, error) {
}, nil
}
// Open creates and returns a underlying sql.DB object for oracle.
// Open creates and returns an underlying sql.DB object for oracle.
func (d *DriverOracle) Open(config *ConfigNode) (db *sql.DB, err error) {
var (
source string
driver = "oci8"
source string
underlyingDriverName = "oci8"
)
if config.Link != "" {
source = config.Link
@ -56,10 +56,10 @@ func (d *DriverOracle) Open(config *ConfigNode) (db *sql.DB, err error) {
)
}
intlog.Printf(d.GetCtx(), "Open: %s", source)
if db, err = sql.Open(driver, source); err != nil {
if db, err = sql.Open(underlyingDriverName, source); err != nil {
err = gerror.WrapCodef(
gcode.CodeDbOperationError, err,
`sql.Open failed for driver "%s" by source "%s"`, driver, source,
`sql.Open failed for driver "%s" by source "%s"`, underlyingDriverName, source,
)
return nil, err
}

View File

@ -41,8 +41,8 @@ func (d *DriverPgsql) New(core *Core, node *ConfigNode) (DB, error) {
// Open creates and returns an underlying sql.DB object for pgsql.
func (d *DriverPgsql) Open(config *ConfigNode) (db *sql.DB, err error) {
var (
source string
driver = "postgres"
source string
underlyingDriverName = "postgres"
)
if config.Link != "" {
source = config.Link
@ -56,10 +56,10 @@ func (d *DriverPgsql) Open(config *ConfigNode) (db *sql.DB, err error) {
}
}
intlog.Printf(d.GetCtx(), "Open: %s", source)
if db, err = sql.Open(driver, source); err != nil {
if db, err = sql.Open(underlyingDriverName, source); err != nil {
err = gerror.WrapCodef(
gcode.CodeDbOperationError, err,
`sql.Open failed for driver "%s" by source "%s"`, driver, source,
`sql.Open failed for driver "%s" by source "%s"`, underlyingDriverName, source,
)
return nil, err
}

View File

@ -39,8 +39,8 @@ func (d *DriverSqlite) New(core *Core, node *ConfigNode) (DB, error) {
// Open creates and returns a underlying sql.DB object for sqlite.
func (d *DriverSqlite) Open(config *ConfigNode) (db *sql.DB, err error) {
var (
source string
driver = "sqlite3"
source string
underlyingDriverName = "sqlite3"
)
if config.Link != "" {
source = config.Link
@ -52,10 +52,10 @@ func (d *DriverSqlite) Open(config *ConfigNode) (db *sql.DB, err error) {
source = absolutePath
}
intlog.Printf(d.GetCtx(), "Open: %s", source)
if db, err = sql.Open(driver, source); err != nil {
if db, err = sql.Open(underlyingDriverName, source); err != nil {
err = gerror.WrapCodef(
gcode.CodeDbOperationError, err,
`sql.Open failed for driver "%s" by source "%s"`, driver, source,
`sql.Open failed for driver "%s" by source "%s"`, underlyingDriverName, source,
)
return nil, err
}

View File

@ -42,7 +42,7 @@ func (s *Schema) Model(table string) *Model {
}
// Do not change the schema of the original db,
// it here creates a new db and changes its schema.
db, err := New(m.db.GetGroup())
db, err := NewByGroup(m.db.GetGroup())
if err != nil {
panic(err)
}

View File

@ -69,7 +69,7 @@ func init() {
gdb.AddConfigNode(gdb.DefaultGroupName, configNode)
// Default db.
if r, err := gdb.New(); err != nil {
if r, err := gdb.NewByGroup(); err != nil {
gtest.Error(err)
} else {
db = r
@ -84,7 +84,7 @@ func init() {
db.SetSchema(TestSchema1)
// Prefix db.
if r, err := gdb.New("prefix"); err != nil {
if r, err := gdb.NewByGroup("prefix"); err != nil {
gtest.Error(err)
} else {
dbPrefix = r
@ -98,7 +98,7 @@ func init() {
dbPrefix.SetSchema(TestSchema1)
// Invalid db.
if r, err := gdb.New("nodeinvalid"); err != nil {
if r, err := gdb.NewByGroup("nodeinvalid"); err != nil {
gtest.Error(err)
} else {
dbInvalid = r

View File

@ -22,6 +22,24 @@ import (
"github.com/gogf/gf/v2/text/gstr"
)
func Test_New(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
node := gdb.ConfigNode{
Host: "127.0.0.1",
Port: "3306",
User: TestDbUser,
Pass: TestDbPass,
Type: gdb.DriverNameMysql,
}
newDb, err := gdb.New(node)
t.AssertNil(err)
value, err := newDb.GetValue(ctx, `select 1`)
t.AssertNil(err)
t.Assert(value, `1`)
t.AssertNil(newDb.Close(ctx))
})
}
func Test_DB_Ping(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
err1 := db.PingMaster()

View File

@ -51,7 +51,7 @@ func init() {
}
AddConfigNode(DefaultGroupName, configNode)
// Default db.
if r, err := New(); err != nil {
if r, err := NewByGroup(); err != nil {
gtest.Error(err)
} else {
db = r

View File

@ -143,7 +143,7 @@ func Database(name ...string) gdb.DB {
}
// Create a new ORM object with given configurations.
if db, err := gdb.New(name...); err == nil {
if db, err := gdb.NewByGroup(name...); err == nil {
// Initialize logger for ORM.
var (
loggerConfigMap map[string]interface{}