1
0
mirror of https://github.com/gogf/gf.git synced 2025-04-05 03:05:05 +08:00

feat(database/gdb): add Raw support for Fields function of gdb.Model (#3873)

This commit is contained in:
John Guo 2024-10-21 09:22:31 +08:00 committed by GitHub
parent b1d875a31f
commit 7dd38a1700
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 127 additions and 71 deletions

View File

@ -4838,3 +4838,25 @@ func Test_OrderBy_Statement_Generated(t *testing.T) {
t.Assert(rawSql, expectSql)
})
}
func Test_Fields_Raw(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
table := createInitTable()
defer dropTable(table)
one, err := db.Model(table).Fields(gdb.Raw("1")).One()
t.AssertNil(err)
t.Assert(one["1"], 1)
one, err = db.Model(table).Fields(gdb.Raw("2")).One()
t.AssertNil(err)
t.Assert(one["2"], 2)
one, err = db.Model(table).Fields(gdb.Raw("2")).Where("id", 2).One()
t.AssertNil(err)
t.Assert(one["2"], 2)
one, err = db.Model(table).Fields(gdb.Raw("2")).Where("id", 10000000000).One()
t.AssertNil(err)
t.Assert(len(one), 0)
})
}

View File

@ -650,6 +650,7 @@ func Test_Model_AllAndCount(t *testing.T) {
t.Assert(len(result), TableSize)
t.Assert(count, TableSize)
})
// AllAndCount with no data
gtest.C(t, func(t *gtest.T) {
result, count, err := db.Model(table).Where("id<0").AllAndCount(false)

View File

@ -344,8 +344,8 @@ func doQuoteString(s, charLeft, charRight string) string {
return gstr.Join(array1, ",")
}
func getFieldsFromStructOrMap(structOrMap interface{}) (fields []string) {
fields = []string{}
func getFieldsFromStructOrMap(structOrMap any) (fields []any) {
fields = []any{}
if utils.IsStruct(structOrMap) {
structFields, _ := gstructs.Fields(gstructs.FieldsInput{
Pointer: structOrMap,
@ -362,7 +362,7 @@ func getFieldsFromStructOrMap(structOrMap interface{}) (fields []string) {
}
}
} else {
fields = gutil.Keys(structOrMap)
fields = gconv.Interfaces(gutil.Keys(structOrMap))
}
return
}

View File

@ -24,8 +24,8 @@ type Model struct {
linkType int // Mark for operation on master or slave.
tablesInit string // Table names when model initialization.
tables string // Operation table names, which can be more than one table names and aliases, like: "user", "user u", "user u, user_detail ud".
fields string // Operation fields, multiple fields joined using char ','.
fieldsEx []string // Excluded operation fields, it here uses slice instead of string type for quick filtering.
fields []any // Operation fields, multiple fields joined using char ','.
fieldsEx []any // Excluded operation fields, it here uses slice instead of string type for quick filtering.
withArray []interface{} // Arguments for With feature.
withAll bool // Enable model association operations on all objects that have "with" tag in the struct.
extraArgs []interface{} // Extra custom arguments for sql, which are prepended to the arguments before sql committed to underlying driver.
@ -65,7 +65,7 @@ type ChunkHandler func(result Result, err error) bool
const (
linkTypeMaster = 1
linkTypeSlave = 2
defaultFields = "*"
defaultField = "*"
whereHolderOperatorWhere = 1
whereHolderOperatorAnd = 2
whereHolderOperatorOr = 3
@ -132,7 +132,6 @@ func (c *Core) Model(tableNameQueryOrStruct ...interface{}) *Model {
schema: c.schema,
tablesInit: tableStr,
tables: tableStr,
fields: defaultFields,
start: -1,
offset: -1,
filter: true,
@ -281,8 +280,12 @@ func (m *Model) Clone() *Model {
newModel.whereBuilder = m.whereBuilder.Clone()
newModel.whereBuilder.model = newModel
// Shallow copy slice attributes.
if n := len(m.fields); n > 0 {
newModel.fields = make([]any, n)
copy(newModel.fields, m.fields)
}
if n := len(m.fieldsEx); n > 0 {
newModel.fieldsEx = make([]string, n)
newModel.fieldsEx = make([]any, n)
copy(newModel.fieldsEx, m.fieldsEx)
}
if n := len(m.extraArgs); n > 0 {

View File

@ -33,7 +33,7 @@ func (m *Model) Fields(fieldNamesOrMapStruct ...interface{}) *Model {
return m
}
model := m.getModel()
return model.appendFieldsByStr(gstr.Join(fields, ","))
return model.appendToFields(fields...)
}
// FieldsPrefix performs as function Fields but add extra prefix for each field.
@ -45,9 +45,11 @@ func (m *Model) FieldsPrefix(prefixOrAlias string, fieldNamesOrMapStruct ...inte
if len(fields) == 0 {
return m
}
gstr.PrefixArray(fields, prefixOrAlias+".")
for i, field := range fields {
fields[i] = prefixOrAlias + "." + gconv.String(field)
}
model := m.getModel()
return model.appendFieldsByStr(gstr.Join(fields, ","))
return model.appendToFields(fields...)
}
// FieldsEx appends `fieldNamesOrMapStruct` to the excluded operation fields of the model,
@ -84,7 +86,9 @@ func (m *Model) FieldsExPrefix(prefixOrAlias string, fieldNamesOrMapStruct ...in
m.getTableNameByPrefixOrAlias(prefixOrAlias),
fieldNamesOrMapStruct...,
)
gstr.PrefixArray(model.fieldsEx, prefixOrAlias+".")
for i, field := range model.fieldsEx {
model.fieldsEx[i] = prefixOrAlias + "." + gconv.String(field)
}
return model
}
@ -95,7 +99,7 @@ func (m *Model) FieldCount(column string, as ...string) *Model {
asStr = fmt.Sprintf(` AS %s`, m.db.GetCore().QuoteWord(as[0]))
}
model := m.getModel()
return model.appendFieldsByStr(
return model.appendToFields(
fmt.Sprintf(`COUNT(%s)%s`, m.QuoteWord(column), asStr),
)
}
@ -107,7 +111,7 @@ func (m *Model) FieldSum(column string, as ...string) *Model {
asStr = fmt.Sprintf(` AS %s`, m.db.GetCore().QuoteWord(as[0]))
}
model := m.getModel()
return model.appendFieldsByStr(
return model.appendToFields(
fmt.Sprintf(`SUM(%s)%s`, m.QuoteWord(column), asStr),
)
}
@ -119,7 +123,7 @@ func (m *Model) FieldMin(column string, as ...string) *Model {
asStr = fmt.Sprintf(` AS %s`, m.db.GetCore().QuoteWord(as[0]))
}
model := m.getModel()
return model.appendFieldsByStr(
return model.appendToFields(
fmt.Sprintf(`MIN(%s)%s`, m.QuoteWord(column), asStr),
)
}
@ -131,7 +135,7 @@ func (m *Model) FieldMax(column string, as ...string) *Model {
asStr = fmt.Sprintf(` AS %s`, m.db.GetCore().QuoteWord(as[0]))
}
model := m.getModel()
return model.appendFieldsByStr(
return model.appendToFields(
fmt.Sprintf(`MAX(%s)%s`, m.QuoteWord(column), asStr),
)
}
@ -143,7 +147,7 @@ func (m *Model) FieldAvg(column string, as ...string) *Model {
asStr = fmt.Sprintf(` AS %s`, m.db.GetCore().QuoteWord(as[0]))
}
model := m.getModel()
return model.appendFieldsByStr(
return model.appendToFields(
fmt.Sprintf(`AVG(%s)%s`, m.QuoteWord(column), asStr),
)
}
@ -218,7 +222,7 @@ func (m *Model) HasField(field string) (bool, error) {
}
// getFieldsFrom retrieves, filters and returns fields name from table `table`.
func (m *Model) filterFieldsFrom(table string, fieldNamesOrMapStruct ...interface{}) []string {
func (m *Model) filterFieldsFrom(table string, fieldNamesOrMapStruct ...any) []any {
length := len(fieldNamesOrMapStruct)
if length == 0 {
return nil
@ -227,7 +231,7 @@ func (m *Model) filterFieldsFrom(table string, fieldNamesOrMapStruct ...interfac
// String slice.
case length >= 2:
return m.mappingAndFilterToTableFields(
table, gconv.Strings(fieldNamesOrMapStruct), true,
table, fieldNamesOrMapStruct, true,
)
// It needs type asserting.
@ -235,13 +239,13 @@ func (m *Model) filterFieldsFrom(table string, fieldNamesOrMapStruct ...interfac
structOrMap := fieldNamesOrMapStruct[0]
switch r := structOrMap.(type) {
case string:
return m.mappingAndFilterToTableFields(table, []string{r}, false)
return m.mappingAndFilterToTableFields(table, []any{r}, false)
case []string:
return m.mappingAndFilterToTableFields(table, r, true)
return m.mappingAndFilterToTableFields(table, gconv.Interfaces(r), true)
case Raw, *Raw:
return []string{gconv.String(structOrMap)}
return []any{structOrMap}
default:
return m.mappingAndFilterToTableFields(table, getFieldsFromStructOrMap(structOrMap), true)
@ -252,19 +256,13 @@ func (m *Model) filterFieldsFrom(table string, fieldNamesOrMapStruct ...interfac
}
}
func (m *Model) appendFieldsByStr(fields string) *Model {
if fields != "" {
model := m.getModel()
if model.fields == defaultFields {
model.fields = ""
}
if model.fields != "" {
model.fields += ","
}
model.fields += fields
return model
func (m *Model) appendToFields(fields ...any) *Model {
if len(fields) == 0 {
return m
}
return m
model := m.getModel()
model.fields = append(model.fields, fields...)
return model
}
func (m *Model) isFieldInFieldsEx(field string) bool {

View File

@ -54,7 +54,7 @@ func (m *Model) AllAndCount(useFieldForCount bool) (result Result, totalCount in
// If useFieldForCount is false, set the fields to a constant value of 1 for counting
if !useFieldForCount {
countModel.fields = "1"
countModel.fields = []any{Raw("1")}
}
// Get the total count of records
@ -178,7 +178,7 @@ func (m *Model) Array(fieldsAndWhere ...interface{}) ([]Value, error) {
func (m *Model) doStruct(pointer interface{}, where ...interface{}) error {
model := m
// Auto selecting fields by struct attributes.
if len(model.fieldsEx) == 0 && (model.fields == "" || model.fields == "*") {
if len(model.fieldsEx) == 0 && len(model.fields) == 0 {
if v, ok := pointer.(reflect.Value); ok {
model = m.Fields(v.Interface())
} else {
@ -214,7 +214,7 @@ func (m *Model) doStruct(pointer interface{}, where ...interface{}) error {
func (m *Model) doStructs(pointer interface{}, where ...interface{}) error {
model := m
// Auto selecting fields by struct attributes.
if len(model.fieldsEx) == 0 && (model.fields == "" || model.fields == "*") {
if len(model.fieldsEx) == 0 && len(model.fields) == 0 {
if v, ok := pointer.(reflect.Value); ok {
model = m.Fields(
reflect.New(
@ -316,7 +316,7 @@ func (m *Model) ScanAndCount(pointer interface{}, totalCount *int, useFieldForCo
countModel := m.Clone()
// If useFieldForCount is false, set the fields to a constant value of 1 for counting
if !useFieldForCount {
countModel.fields = "1"
countModel.fields = []any{Raw("1")}
}
// Get the total count of records
@ -343,7 +343,7 @@ func (m *Model) ScanList(structSlicePointer interface{}, bindToAttrName string,
if err != nil {
return err
}
if m.fields != defaultFields || len(m.fieldsEx) != 0 {
if len(m.fields) > 0 || len(m.fieldsEx) != 0 {
// There are custom fields.
result, err = m.All()
} else {
@ -604,7 +604,9 @@ func (m *Model) doGetAll(ctx context.Context, limit1 bool, where ...interface{})
}
// doGetAllBySql does the select statement on the database.
func (m *Model) doGetAllBySql(ctx context.Context, queryType queryType, sql string, args ...interface{}) (result Result, err error) {
func (m *Model) doGetAllBySql(
ctx context.Context, queryType queryType, sql string, args ...interface{},
) (result Result, err error) {
if result, err = m.getSelectResultFromCache(ctx, sql, args...); err != nil || result != nil {
return
}
@ -635,10 +637,10 @@ func (m *Model) getFormattedSqlAndArgs(
switch queryType {
case queryTypeCount:
queryFields := "COUNT(1)"
if m.fields != "" && m.fields != "*" {
if len(m.fields) > 0 {
// DO NOT quote the m.fields here, in case of fields like:
// DISTINCT t.user_id uid
queryFields = fmt.Sprintf(`COUNT(%s%s)`, m.distinct, m.fields)
queryFields = fmt.Sprintf(`COUNT(%s%s)`, m.distinct, m.getFieldsAsStr())
}
// Raw SQL Model.
if m.rawSql != "" {
@ -691,29 +693,50 @@ func (m *Model) getAutoPrefix() string {
return autoPrefix
}
func (m *Model) getFieldsAsStr() string {
var (
fieldsStr string
core = m.db.GetCore()
)
for _, v := range m.fields {
field := gconv.String(v)
switch {
case gstr.ContainsAny(field, "()"):
case gstr.ContainsAny(field, ". "):
default:
switch v.(type) {
case Raw, *Raw:
default:
field = core.QuoteString(field)
}
}
if fieldsStr != "" {
fieldsStr += ","
}
fieldsStr += field
}
return fieldsStr
}
// getFieldsFiltered checks the fields and fieldsEx attributes, filters and returns the fields that will
// really be committed to underlying database driver.
func (m *Model) getFieldsFiltered() string {
if len(m.fieldsEx) == 0 {
// No filtering, containing special chars.
if gstr.ContainsAny(m.fields, "()") {
return m.fields
}
// No filtering.
if !gstr.ContainsAny(m.fields, ". ") {
return m.db.GetCore().QuoteString(m.fields)
}
return m.fields
if len(m.fieldsEx) == 0 && len(m.fields) == 0 {
return defaultField
}
if len(m.fieldsEx) == 0 && len(m.fields) > 0 {
return m.getFieldsAsStr()
}
var (
fieldsArray []string
fieldsExSet = gset.NewStrSetFrom(m.fieldsEx)
fieldsExSet = gset.NewStrSetFrom(gconv.Strings(m.fieldsEx))
)
if m.fields != "*" {
if len(m.fields) > 0 {
// Filter custom fields with fieldEx.
fieldsArray = make([]string, 0, 8)
for _, v := range gstr.SplitAndTrim(m.fields, ",") {
fieldsArray = append(fieldsArray, v[gstr.PosR(v, "-")+1:])
for _, v := range m.fields {
field := gconv.String(v)
fieldsArray = append(fieldsArray, field[gstr.PosR(field, "-")+1:])
}
} else {
if gstr.Contains(m.tables, " ") {

View File

@ -14,6 +14,7 @@ import (
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
"github.com/gogf/gf/v2/util/gutil"
)
@ -52,7 +53,7 @@ func (m *Model) getModel() *Model {
// Eg:
// ID -> id
// NICK_Name -> nickname.
func (m *Model) mappingAndFilterToTableFields(table string, fields []string, filter bool) []string {
func (m *Model) mappingAndFilterToTableFields(table string, fields []any, filter bool) []any {
var fieldsTable = table
if fieldsTable != "" {
hasTable, _ := m.db.GetCore().HasTable(fieldsTable)
@ -68,18 +69,24 @@ func (m *Model) mappingAndFilterToTableFields(table string, fields []string, fil
if len(fieldsMap) == 0 {
return fields
}
var outputFieldsArray = make([]string, 0)
var outputFieldsArray = make([]any, 0)
fieldsKeyMap := make(map[string]interface{}, len(fieldsMap))
for k := range fieldsMap {
fieldsKeyMap[k] = nil
}
for _, field := range fields {
var inputFieldsArray []string
if gregex.IsMatchString(regularFieldNameWithoutDotRegPattern, field) {
inputFieldsArray = append(inputFieldsArray, field)
} else if gregex.IsMatchString(regularFieldNameWithCommaRegPattern, field) {
inputFieldsArray = gstr.SplitAndTrim(field, ",")
} else {
var (
fieldStr = gconv.String(field)
inputFieldsArray []string
)
switch {
case gregex.IsMatchString(regularFieldNameWithoutDotRegPattern, fieldStr):
inputFieldsArray = append(inputFieldsArray, fieldStr)
case gregex.IsMatchString(regularFieldNameWithCommaRegPattern, fieldStr):
inputFieldsArray = gstr.SplitAndTrim(fieldStr, ",")
default:
// Example:
// user.id, user.name
// replace(concat_ws(',',lpad(s.id, 6, '0'),s.name),',','') `code`
@ -186,26 +193,26 @@ func (m *Model) doMappingAndFilterForInsertOrUpdateDataMap(data Map, allowOmitEm
data = tempMap
}
if len(m.fields) > 0 && m.fields != "*" {
if len(m.fields) > 0 {
// Keep specified fields.
var (
set = gset.NewStrSetFrom(gstr.SplitAndTrim(m.fields, ","))
fieldSet = gset.NewStrSetFrom(gconv.Strings(m.fields))
charL, charR = m.db.GetChars()
chars = charL + charR
)
set.Walk(func(item string) string {
fieldSet.Walk(func(item string) string {
return gstr.Trim(item, chars)
})
for k := range data {
k = gstr.Trim(k, chars)
if !set.Contains(k) {
if !fieldSet.Contains(k) {
delete(data, k)
}
}
} else if len(m.fieldsEx) > 0 {
// Filter specified fields.
for _, v := range m.fieldsEx {
delete(data, v)
delete(data, gconv.String(v))
}
}
return data, nil

View File

@ -18,9 +18,9 @@ const (
)
// Keys retrieves and returns the keys from given map or struct.
func Keys(mapOrStruct interface{}) (keysOrAttrs []string) {
func Keys(mapOrStruct any) (keysOrAttrs []string) {
keysOrAttrs = make([]string, 0)
if m, ok := mapOrStruct.(map[string]interface{}); ok {
if m, ok := mapOrStruct.(map[string]any); ok {
for k := range m {
keysOrAttrs = append(keysOrAttrs, k)
}
@ -63,6 +63,7 @@ func Keys(mapOrStruct interface{}) (keysOrAttrs []string) {
keysOrAttrs = append(keysOrAttrs, fieldType.Name)
}
}
default:
}
return
}
@ -108,6 +109,7 @@ func Values(mapOrStruct interface{}) (values []interface{}) {
values = append(values, reflectValue.Field(i).Interface())
}
}
default:
}
return
}