diff --git a/contrib/drivers/mysql/mysql_z_unit_model_test.go b/contrib/drivers/mysql/mysql_z_unit_model_test.go index 83c348b1a..652f69dad 100644 --- a/contrib/drivers/mysql/mysql_z_unit_model_test.go +++ b/contrib/drivers/mysql/mysql_z_unit_model_test.go @@ -4838,3 +4838,25 @@ func Test_OrderBy_Statement_Generated(t *testing.T) { t.Assert(rawSql, expectSql) }) } + +func Test_Fields_Raw(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + table := createInitTable() + defer dropTable(table) + one, err := db.Model(table).Fields(gdb.Raw("1")).One() + t.AssertNil(err) + t.Assert(one["1"], 1) + + one, err = db.Model(table).Fields(gdb.Raw("2")).One() + t.AssertNil(err) + t.Assert(one["2"], 2) + + one, err = db.Model(table).Fields(gdb.Raw("2")).Where("id", 2).One() + t.AssertNil(err) + t.Assert(one["2"], 2) + + one, err = db.Model(table).Fields(gdb.Raw("2")).Where("id", 10000000000).One() + t.AssertNil(err) + t.Assert(len(one), 0) + }) +} diff --git a/contrib/drivers/sqlite/sqlite_z_unit_model_test.go b/contrib/drivers/sqlite/sqlite_z_unit_model_test.go index dd0eaba96..c9e78d798 100644 --- a/contrib/drivers/sqlite/sqlite_z_unit_model_test.go +++ b/contrib/drivers/sqlite/sqlite_z_unit_model_test.go @@ -650,6 +650,7 @@ func Test_Model_AllAndCount(t *testing.T) { t.Assert(len(result), TableSize) t.Assert(count, TableSize) }) + // AllAndCount with no data gtest.C(t, func(t *gtest.T) { result, count, err := db.Model(table).Where("id<0").AllAndCount(false) diff --git a/database/gdb/gdb_func.go b/database/gdb/gdb_func.go index 9064f4a2e..752710d1c 100644 --- a/database/gdb/gdb_func.go +++ b/database/gdb/gdb_func.go @@ -344,8 +344,8 @@ func doQuoteString(s, charLeft, charRight string) string { return gstr.Join(array1, ",") } -func getFieldsFromStructOrMap(structOrMap interface{}) (fields []string) { - fields = []string{} +func getFieldsFromStructOrMap(structOrMap any) (fields []any) { + fields = []any{} if utils.IsStruct(structOrMap) { structFields, _ := gstructs.Fields(gstructs.FieldsInput{ Pointer: structOrMap, @@ -362,7 +362,7 @@ func getFieldsFromStructOrMap(structOrMap interface{}) (fields []string) { } } } else { - fields = gutil.Keys(structOrMap) + fields = gconv.Interfaces(gutil.Keys(structOrMap)) } return } diff --git a/database/gdb/gdb_model.go b/database/gdb/gdb_model.go index d2c812ccb..beaa51500 100644 --- a/database/gdb/gdb_model.go +++ b/database/gdb/gdb_model.go @@ -24,8 +24,8 @@ type Model struct { linkType int // Mark for operation on master or slave. tablesInit string // Table names when model initialization. tables string // Operation table names, which can be more than one table names and aliases, like: "user", "user u", "user u, user_detail ud". - fields string // Operation fields, multiple fields joined using char ','. - fieldsEx []string // Excluded operation fields, it here uses slice instead of string type for quick filtering. + fields []any // Operation fields, multiple fields joined using char ','. + fieldsEx []any // Excluded operation fields, it here uses slice instead of string type for quick filtering. withArray []interface{} // Arguments for With feature. withAll bool // Enable model association operations on all objects that have "with" tag in the struct. extraArgs []interface{} // Extra custom arguments for sql, which are prepended to the arguments before sql committed to underlying driver. @@ -65,7 +65,7 @@ type ChunkHandler func(result Result, err error) bool const ( linkTypeMaster = 1 linkTypeSlave = 2 - defaultFields = "*" + defaultField = "*" whereHolderOperatorWhere = 1 whereHolderOperatorAnd = 2 whereHolderOperatorOr = 3 @@ -132,7 +132,6 @@ func (c *Core) Model(tableNameQueryOrStruct ...interface{}) *Model { schema: c.schema, tablesInit: tableStr, tables: tableStr, - fields: defaultFields, start: -1, offset: -1, filter: true, @@ -281,8 +280,12 @@ func (m *Model) Clone() *Model { newModel.whereBuilder = m.whereBuilder.Clone() newModel.whereBuilder.model = newModel // Shallow copy slice attributes. + if n := len(m.fields); n > 0 { + newModel.fields = make([]any, n) + copy(newModel.fields, m.fields) + } if n := len(m.fieldsEx); n > 0 { - newModel.fieldsEx = make([]string, n) + newModel.fieldsEx = make([]any, n) copy(newModel.fieldsEx, m.fieldsEx) } if n := len(m.extraArgs); n > 0 { diff --git a/database/gdb/gdb_model_fields.go b/database/gdb/gdb_model_fields.go index 1615090da..9f7e2cc7c 100644 --- a/database/gdb/gdb_model_fields.go +++ b/database/gdb/gdb_model_fields.go @@ -33,7 +33,7 @@ func (m *Model) Fields(fieldNamesOrMapStruct ...interface{}) *Model { return m } model := m.getModel() - return model.appendFieldsByStr(gstr.Join(fields, ",")) + return model.appendToFields(fields...) } // FieldsPrefix performs as function Fields but add extra prefix for each field. @@ -45,9 +45,11 @@ func (m *Model) FieldsPrefix(prefixOrAlias string, fieldNamesOrMapStruct ...inte if len(fields) == 0 { return m } - gstr.PrefixArray(fields, prefixOrAlias+".") + for i, field := range fields { + fields[i] = prefixOrAlias + "." + gconv.String(field) + } model := m.getModel() - return model.appendFieldsByStr(gstr.Join(fields, ",")) + return model.appendToFields(fields...) } // FieldsEx appends `fieldNamesOrMapStruct` to the excluded operation fields of the model, @@ -84,7 +86,9 @@ func (m *Model) FieldsExPrefix(prefixOrAlias string, fieldNamesOrMapStruct ...in m.getTableNameByPrefixOrAlias(prefixOrAlias), fieldNamesOrMapStruct..., ) - gstr.PrefixArray(model.fieldsEx, prefixOrAlias+".") + for i, field := range model.fieldsEx { + model.fieldsEx[i] = prefixOrAlias + "." + gconv.String(field) + } return model } @@ -95,7 +99,7 @@ func (m *Model) FieldCount(column string, as ...string) *Model { asStr = fmt.Sprintf(` AS %s`, m.db.GetCore().QuoteWord(as[0])) } model := m.getModel() - return model.appendFieldsByStr( + return model.appendToFields( fmt.Sprintf(`COUNT(%s)%s`, m.QuoteWord(column), asStr), ) } @@ -107,7 +111,7 @@ func (m *Model) FieldSum(column string, as ...string) *Model { asStr = fmt.Sprintf(` AS %s`, m.db.GetCore().QuoteWord(as[0])) } model := m.getModel() - return model.appendFieldsByStr( + return model.appendToFields( fmt.Sprintf(`SUM(%s)%s`, m.QuoteWord(column), asStr), ) } @@ -119,7 +123,7 @@ func (m *Model) FieldMin(column string, as ...string) *Model { asStr = fmt.Sprintf(` AS %s`, m.db.GetCore().QuoteWord(as[0])) } model := m.getModel() - return model.appendFieldsByStr( + return model.appendToFields( fmt.Sprintf(`MIN(%s)%s`, m.QuoteWord(column), asStr), ) } @@ -131,7 +135,7 @@ func (m *Model) FieldMax(column string, as ...string) *Model { asStr = fmt.Sprintf(` AS %s`, m.db.GetCore().QuoteWord(as[0])) } model := m.getModel() - return model.appendFieldsByStr( + return model.appendToFields( fmt.Sprintf(`MAX(%s)%s`, m.QuoteWord(column), asStr), ) } @@ -143,7 +147,7 @@ func (m *Model) FieldAvg(column string, as ...string) *Model { asStr = fmt.Sprintf(` AS %s`, m.db.GetCore().QuoteWord(as[0])) } model := m.getModel() - return model.appendFieldsByStr( + return model.appendToFields( fmt.Sprintf(`AVG(%s)%s`, m.QuoteWord(column), asStr), ) } @@ -218,7 +222,7 @@ func (m *Model) HasField(field string) (bool, error) { } // getFieldsFrom retrieves, filters and returns fields name from table `table`. -func (m *Model) filterFieldsFrom(table string, fieldNamesOrMapStruct ...interface{}) []string { +func (m *Model) filterFieldsFrom(table string, fieldNamesOrMapStruct ...any) []any { length := len(fieldNamesOrMapStruct) if length == 0 { return nil @@ -227,7 +231,7 @@ func (m *Model) filterFieldsFrom(table string, fieldNamesOrMapStruct ...interfac // String slice. case length >= 2: return m.mappingAndFilterToTableFields( - table, gconv.Strings(fieldNamesOrMapStruct), true, + table, fieldNamesOrMapStruct, true, ) // It needs type asserting. @@ -235,13 +239,13 @@ func (m *Model) filterFieldsFrom(table string, fieldNamesOrMapStruct ...interfac structOrMap := fieldNamesOrMapStruct[0] switch r := structOrMap.(type) { case string: - return m.mappingAndFilterToTableFields(table, []string{r}, false) + return m.mappingAndFilterToTableFields(table, []any{r}, false) case []string: - return m.mappingAndFilterToTableFields(table, r, true) + return m.mappingAndFilterToTableFields(table, gconv.Interfaces(r), true) case Raw, *Raw: - return []string{gconv.String(structOrMap)} + return []any{structOrMap} default: return m.mappingAndFilterToTableFields(table, getFieldsFromStructOrMap(structOrMap), true) @@ -252,19 +256,13 @@ func (m *Model) filterFieldsFrom(table string, fieldNamesOrMapStruct ...interfac } } -func (m *Model) appendFieldsByStr(fields string) *Model { - if fields != "" { - model := m.getModel() - if model.fields == defaultFields { - model.fields = "" - } - if model.fields != "" { - model.fields += "," - } - model.fields += fields - return model +func (m *Model) appendToFields(fields ...any) *Model { + if len(fields) == 0 { + return m } - return m + model := m.getModel() + model.fields = append(model.fields, fields...) + return model } func (m *Model) isFieldInFieldsEx(field string) bool { diff --git a/database/gdb/gdb_model_select.go b/database/gdb/gdb_model_select.go index ca047111f..018b37b0a 100644 --- a/database/gdb/gdb_model_select.go +++ b/database/gdb/gdb_model_select.go @@ -54,7 +54,7 @@ func (m *Model) AllAndCount(useFieldForCount bool) (result Result, totalCount in // If useFieldForCount is false, set the fields to a constant value of 1 for counting if !useFieldForCount { - countModel.fields = "1" + countModel.fields = []any{Raw("1")} } // Get the total count of records @@ -178,7 +178,7 @@ func (m *Model) Array(fieldsAndWhere ...interface{}) ([]Value, error) { func (m *Model) doStruct(pointer interface{}, where ...interface{}) error { model := m // Auto selecting fields by struct attributes. - if len(model.fieldsEx) == 0 && (model.fields == "" || model.fields == "*") { + if len(model.fieldsEx) == 0 && len(model.fields) == 0 { if v, ok := pointer.(reflect.Value); ok { model = m.Fields(v.Interface()) } else { @@ -214,7 +214,7 @@ func (m *Model) doStruct(pointer interface{}, where ...interface{}) error { func (m *Model) doStructs(pointer interface{}, where ...interface{}) error { model := m // Auto selecting fields by struct attributes. - if len(model.fieldsEx) == 0 && (model.fields == "" || model.fields == "*") { + if len(model.fieldsEx) == 0 && len(model.fields) == 0 { if v, ok := pointer.(reflect.Value); ok { model = m.Fields( reflect.New( @@ -316,7 +316,7 @@ func (m *Model) ScanAndCount(pointer interface{}, totalCount *int, useFieldForCo countModel := m.Clone() // If useFieldForCount is false, set the fields to a constant value of 1 for counting if !useFieldForCount { - countModel.fields = "1" + countModel.fields = []any{Raw("1")} } // Get the total count of records @@ -343,7 +343,7 @@ func (m *Model) ScanList(structSlicePointer interface{}, bindToAttrName string, if err != nil { return err } - if m.fields != defaultFields || len(m.fieldsEx) != 0 { + if len(m.fields) > 0 || len(m.fieldsEx) != 0 { // There are custom fields. result, err = m.All() } else { @@ -604,7 +604,9 @@ func (m *Model) doGetAll(ctx context.Context, limit1 bool, where ...interface{}) } // doGetAllBySql does the select statement on the database. -func (m *Model) doGetAllBySql(ctx context.Context, queryType queryType, sql string, args ...interface{}) (result Result, err error) { +func (m *Model) doGetAllBySql( + ctx context.Context, queryType queryType, sql string, args ...interface{}, +) (result Result, err error) { if result, err = m.getSelectResultFromCache(ctx, sql, args...); err != nil || result != nil { return } @@ -635,10 +637,10 @@ func (m *Model) getFormattedSqlAndArgs( switch queryType { case queryTypeCount: queryFields := "COUNT(1)" - if m.fields != "" && m.fields != "*" { + if len(m.fields) > 0 { // DO NOT quote the m.fields here, in case of fields like: // DISTINCT t.user_id uid - queryFields = fmt.Sprintf(`COUNT(%s%s)`, m.distinct, m.fields) + queryFields = fmt.Sprintf(`COUNT(%s%s)`, m.distinct, m.getFieldsAsStr()) } // Raw SQL Model. if m.rawSql != "" { @@ -691,29 +693,50 @@ func (m *Model) getAutoPrefix() string { return autoPrefix } +func (m *Model) getFieldsAsStr() string { + var ( + fieldsStr string + core = m.db.GetCore() + ) + for _, v := range m.fields { + field := gconv.String(v) + switch { + case gstr.ContainsAny(field, "()"): + case gstr.ContainsAny(field, ". "): + default: + switch v.(type) { + case Raw, *Raw: + default: + field = core.QuoteString(field) + } + } + if fieldsStr != "" { + fieldsStr += "," + } + fieldsStr += field + } + return fieldsStr +} + // getFieldsFiltered checks the fields and fieldsEx attributes, filters and returns the fields that will // really be committed to underlying database driver. func (m *Model) getFieldsFiltered() string { - if len(m.fieldsEx) == 0 { - // No filtering, containing special chars. - if gstr.ContainsAny(m.fields, "()") { - return m.fields - } - // No filtering. - if !gstr.ContainsAny(m.fields, ". ") { - return m.db.GetCore().QuoteString(m.fields) - } - return m.fields + if len(m.fieldsEx) == 0 && len(m.fields) == 0 { + return defaultField + } + if len(m.fieldsEx) == 0 && len(m.fields) > 0 { + return m.getFieldsAsStr() } var ( fieldsArray []string - fieldsExSet = gset.NewStrSetFrom(m.fieldsEx) + fieldsExSet = gset.NewStrSetFrom(gconv.Strings(m.fieldsEx)) ) - if m.fields != "*" { + if len(m.fields) > 0 { // Filter custom fields with fieldEx. fieldsArray = make([]string, 0, 8) - for _, v := range gstr.SplitAndTrim(m.fields, ",") { - fieldsArray = append(fieldsArray, v[gstr.PosR(v, "-")+1:]) + for _, v := range m.fields { + field := gconv.String(v) + fieldsArray = append(fieldsArray, field[gstr.PosR(field, "-")+1:]) } } else { if gstr.Contains(m.tables, " ") { diff --git a/database/gdb/gdb_model_utility.go b/database/gdb/gdb_model_utility.go index 1b8d3704f..79fe878ac 100644 --- a/database/gdb/gdb_model_utility.go +++ b/database/gdb/gdb_model_utility.go @@ -14,6 +14,7 @@ import ( "github.com/gogf/gf/v2/os/gtime" "github.com/gogf/gf/v2/text/gregex" "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gconv" "github.com/gogf/gf/v2/util/gutil" ) @@ -52,7 +53,7 @@ func (m *Model) getModel() *Model { // Eg: // ID -> id // NICK_Name -> nickname. -func (m *Model) mappingAndFilterToTableFields(table string, fields []string, filter bool) []string { +func (m *Model) mappingAndFilterToTableFields(table string, fields []any, filter bool) []any { var fieldsTable = table if fieldsTable != "" { hasTable, _ := m.db.GetCore().HasTable(fieldsTable) @@ -68,18 +69,24 @@ func (m *Model) mappingAndFilterToTableFields(table string, fields []string, fil if len(fieldsMap) == 0 { return fields } - var outputFieldsArray = make([]string, 0) + var outputFieldsArray = make([]any, 0) fieldsKeyMap := make(map[string]interface{}, len(fieldsMap)) for k := range fieldsMap { fieldsKeyMap[k] = nil } for _, field := range fields { - var inputFieldsArray []string - if gregex.IsMatchString(regularFieldNameWithoutDotRegPattern, field) { - inputFieldsArray = append(inputFieldsArray, field) - } else if gregex.IsMatchString(regularFieldNameWithCommaRegPattern, field) { - inputFieldsArray = gstr.SplitAndTrim(field, ",") - } else { + var ( + fieldStr = gconv.String(field) + inputFieldsArray []string + ) + switch { + case gregex.IsMatchString(regularFieldNameWithoutDotRegPattern, fieldStr): + inputFieldsArray = append(inputFieldsArray, fieldStr) + + case gregex.IsMatchString(regularFieldNameWithCommaRegPattern, fieldStr): + inputFieldsArray = gstr.SplitAndTrim(fieldStr, ",") + + default: // Example: // user.id, user.name // replace(concat_ws(',',lpad(s.id, 6, '0'),s.name),',','') `code` @@ -186,26 +193,26 @@ func (m *Model) doMappingAndFilterForInsertOrUpdateDataMap(data Map, allowOmitEm data = tempMap } - if len(m.fields) > 0 && m.fields != "*" { + if len(m.fields) > 0 { // Keep specified fields. var ( - set = gset.NewStrSetFrom(gstr.SplitAndTrim(m.fields, ",")) + fieldSet = gset.NewStrSetFrom(gconv.Strings(m.fields)) charL, charR = m.db.GetChars() chars = charL + charR ) - set.Walk(func(item string) string { + fieldSet.Walk(func(item string) string { return gstr.Trim(item, chars) }) for k := range data { k = gstr.Trim(k, chars) - if !set.Contains(k) { + if !fieldSet.Contains(k) { delete(data, k) } } } else if len(m.fieldsEx) > 0 { // Filter specified fields. for _, v := range m.fieldsEx { - delete(data, v) + delete(data, gconv.String(v)) } } return data, nil diff --git a/util/gutil/gutil.go b/util/gutil/gutil.go index b77d6c2fd..2c5a555d1 100644 --- a/util/gutil/gutil.go +++ b/util/gutil/gutil.go @@ -18,9 +18,9 @@ const ( ) // Keys retrieves and returns the keys from given map or struct. -func Keys(mapOrStruct interface{}) (keysOrAttrs []string) { +func Keys(mapOrStruct any) (keysOrAttrs []string) { keysOrAttrs = make([]string, 0) - if m, ok := mapOrStruct.(map[string]interface{}); ok { + if m, ok := mapOrStruct.(map[string]any); ok { for k := range m { keysOrAttrs = append(keysOrAttrs, k) } @@ -63,6 +63,7 @@ func Keys(mapOrStruct interface{}) (keysOrAttrs []string) { keysOrAttrs = append(keysOrAttrs, fieldType.Name) } } + default: } return } @@ -108,6 +109,7 @@ func Values(mapOrStruct interface{}) (values []interface{}) { values = append(values, reflectValue.Field(i).Interface()) } } + default: } return }