refactor(context): refactor Keys type to map[any]any (#3963)

* refactor(context): refactor keys to `map[any]any`

Signed-off-by: Flc゛ <four_leaf_clover@foxmail.com>

* refactor(context): refactor keys to `map[any]any`

Signed-off-by: Flc゛ <four_leaf_clover@foxmail.com>

* style(context): remove empty lines before GetInt16, GetIntSlice, and GetStringMapString methods

- Remove unnecessary empty lines in the context.go file
- Improve code readability and consistency

Signed-off-by: flc1125 <four_leaf_clover@foxmail.com>

* refactor(context): simplify GetStringSlice function

- Replace manual type assertion with generic getTyped function
- Reduce code duplication and improve type safety

Signed-off-by: flc1125 <four_leaf_clover@foxmail.com>

* test(context): improve context.Set and context.Get tests

- Split existing test into separate functions for different scenarios
- Add test for setting and getting values with any key type
- Add test for handling non-comparable keys
- Improve assertions to check for key existence and value correctness

Signed-off-by: flc1125 <four_leaf_clover@foxmail.com>

* refactor(context): replace fmt.Errorf with fmt.Sprintf in panic message

* test(context): remove trailing hyphen from context_test.go

* refactor(context): improve error message for missing key in context

- Remove unnecessary quotes around the key in the error message
- Simplify the error message format for better readability

* test(context): improve panic test message for non-existent key

---------

Signed-off-by: Flc゛ <four_leaf_clover@foxmail.com>
Signed-off-by: flc1125 <four_leaf_clover@foxmail.com>
This commit is contained in:
Flc゛ 2025-05-26 23:15:14 +08:00 committed by GitHub
parent 848e1cdd0d
commit 41d8591eb1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 83 additions and 43 deletions

View File

@ -6,6 +6,7 @@ package gin
import ( import (
"errors" "errors"
"fmt"
"io" "io"
"io/fs" "io/fs"
"log" "log"
@ -72,7 +73,7 @@ type Context struct {
mu sync.RWMutex mu sync.RWMutex
// Keys is a key/value pair exclusively for the context of each request. // Keys is a key/value pair exclusively for the context of each request.
Keys map[string]any Keys map[any]any
// Errors is a list of errors attached to all the handlers/middlewares who used this context. // Errors is a list of errors attached to all the handlers/middlewares who used this context.
Errors errorMsgs Errors errorMsgs
@ -129,7 +130,7 @@ func (c *Context) Copy() *Context {
cp.fullPath = c.fullPath cp.fullPath = c.fullPath
cKeys := c.Keys cKeys := c.Keys
cp.Keys = make(map[string]any, len(cKeys)) cp.Keys = make(map[any]any, len(cKeys))
c.mu.RLock() c.mu.RLock()
for k, v := range cKeys { for k, v := range cKeys {
cp.Keys[k] = v cp.Keys[k] = v
@ -264,11 +265,11 @@ func (c *Context) Error(err error) *Error {
// Set is used to store a new key/value pair exclusively for this context. // Set is used to store a new key/value pair exclusively for this context.
// It also lazy initializes c.Keys if it was not used previously. // It also lazy initializes c.Keys if it was not used previously.
func (c *Context) Set(key string, value any) { func (c *Context) Set(key any, value any) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if c.Keys == nil { if c.Keys == nil {
c.Keys = make(map[string]any) c.Keys = make(map[any]any)
} }
c.Keys[key] = value c.Keys[key] = value
@ -276,7 +277,7 @@ func (c *Context) Set(key string, value any) {
// Get returns the value for the given key, ie: (value, true). // Get returns the value for the given key, ie: (value, true).
// If the value does not exist it returns (nil, false) // If the value does not exist it returns (nil, false)
func (c *Context) Get(key string) (value any, exists bool) { func (c *Context) Get(key any) (value any, exists bool) {
c.mu.RLock() c.mu.RLock()
defer c.mu.RUnlock() defer c.mu.RUnlock()
value, exists = c.Keys[key] value, exists = c.Keys[key]
@ -284,14 +285,14 @@ func (c *Context) Get(key string) (value any, exists bool) {
} }
// MustGet returns the value for the given key if it exists, otherwise it panics. // MustGet returns the value for the given key if it exists, otherwise it panics.
func (c *Context) MustGet(key string) any { func (c *Context) MustGet(key any) any {
if value, exists := c.Get(key); exists { if value, exists := c.Get(key); exists {
return value return value
} }
panic("Key \"" + key + "\" does not exist") panic(fmt.Sprintf("key %v does not exist", key))
} }
func getTyped[T any](c *Context, key string) (res T) { func getTyped[T any](c *Context, key any) (res T) {
if val, ok := c.Get(key); ok && val != nil { if val, ok := c.Get(key); ok && val != nil {
res, _ = val.(T) res, _ = val.(T)
} }
@ -299,162 +300,162 @@ func getTyped[T any](c *Context, key string) (res T) {
} }
// GetString returns the value associated with the key as a string. // GetString returns the value associated with the key as a string.
func (c *Context) GetString(key string) (s string) { func (c *Context) GetString(key any) (s string) {
return getTyped[string](c, key) return getTyped[string](c, key)
} }
// GetBool returns the value associated with the key as a boolean. // GetBool returns the value associated with the key as a boolean.
func (c *Context) GetBool(key string) (b bool) { func (c *Context) GetBool(key any) (b bool) {
return getTyped[bool](c, key) return getTyped[bool](c, key)
} }
// GetInt returns the value associated with the key as an integer. // GetInt returns the value associated with the key as an integer.
func (c *Context) GetInt(key string) (i int) { func (c *Context) GetInt(key any) (i int) {
return getTyped[int](c, key) return getTyped[int](c, key)
} }
// GetInt8 returns the value associated with the key as an integer 8. // GetInt8 returns the value associated with the key as an integer 8.
func (c *Context) GetInt8(key string) (i8 int8) { func (c *Context) GetInt8(key any) (i8 int8) {
return getTyped[int8](c, key) return getTyped[int8](c, key)
} }
// GetInt16 returns the value associated with the key as an integer 16. // GetInt16 returns the value associated with the key as an integer 16.
func (c *Context) GetInt16(key string) (i16 int16) { func (c *Context) GetInt16(key any) (i16 int16) {
return getTyped[int16](c, key) return getTyped[int16](c, key)
} }
// GetInt32 returns the value associated with the key as an integer 32. // GetInt32 returns the value associated with the key as an integer 32.
func (c *Context) GetInt32(key string) (i32 int32) { func (c *Context) GetInt32(key any) (i32 int32) {
return getTyped[int32](c, key) return getTyped[int32](c, key)
} }
// GetInt64 returns the value associated with the key as an integer 64. // GetInt64 returns the value associated with the key as an integer 64.
func (c *Context) GetInt64(key string) (i64 int64) { func (c *Context) GetInt64(key any) (i64 int64) {
return getTyped[int64](c, key) return getTyped[int64](c, key)
} }
// GetUint returns the value associated with the key as an unsigned integer. // GetUint returns the value associated with the key as an unsigned integer.
func (c *Context) GetUint(key string) (ui uint) { func (c *Context) GetUint(key any) (ui uint) {
return getTyped[uint](c, key) return getTyped[uint](c, key)
} }
// GetUint8 returns the value associated with the key as an unsigned integer 8. // GetUint8 returns the value associated with the key as an unsigned integer 8.
func (c *Context) GetUint8(key string) (ui8 uint8) { func (c *Context) GetUint8(key any) (ui8 uint8) {
return getTyped[uint8](c, key) return getTyped[uint8](c, key)
} }
// GetUint16 returns the value associated with the key as an unsigned integer 16. // GetUint16 returns the value associated with the key as an unsigned integer 16.
func (c *Context) GetUint16(key string) (ui16 uint16) { func (c *Context) GetUint16(key any) (ui16 uint16) {
return getTyped[uint16](c, key) return getTyped[uint16](c, key)
} }
// GetUint32 returns the value associated with the key as an unsigned integer 32. // GetUint32 returns the value associated with the key as an unsigned integer 32.
func (c *Context) GetUint32(key string) (ui32 uint32) { func (c *Context) GetUint32(key any) (ui32 uint32) {
return getTyped[uint32](c, key) return getTyped[uint32](c, key)
} }
// GetUint64 returns the value associated with the key as an unsigned integer 64. // GetUint64 returns the value associated with the key as an unsigned integer 64.
func (c *Context) GetUint64(key string) (ui64 uint64) { func (c *Context) GetUint64(key any) (ui64 uint64) {
return getTyped[uint64](c, key) return getTyped[uint64](c, key)
} }
// GetFloat32 returns the value associated with the key as a float32. // GetFloat32 returns the value associated with the key as a float32.
func (c *Context) GetFloat32(key string) (f32 float32) { func (c *Context) GetFloat32(key any) (f32 float32) {
return getTyped[float32](c, key) return getTyped[float32](c, key)
} }
// GetFloat64 returns the value associated with the key as a float64. // GetFloat64 returns the value associated with the key as a float64.
func (c *Context) GetFloat64(key string) (f64 float64) { func (c *Context) GetFloat64(key any) (f64 float64) {
return getTyped[float64](c, key) return getTyped[float64](c, key)
} }
// GetTime returns the value associated with the key as time. // GetTime returns the value associated with the key as time.
func (c *Context) GetTime(key string) (t time.Time) { func (c *Context) GetTime(key any) (t time.Time) {
return getTyped[time.Time](c, key) return getTyped[time.Time](c, key)
} }
// GetDuration returns the value associated with the key as a duration. // GetDuration returns the value associated with the key as a duration.
func (c *Context) GetDuration(key string) (d time.Duration) { func (c *Context) GetDuration(key any) (d time.Duration) {
return getTyped[time.Duration](c, key) return getTyped[time.Duration](c, key)
} }
// GetIntSlice returns the value associated with the key as a slice of integers. // GetIntSlice returns the value associated with the key as a slice of integers.
func (c *Context) GetIntSlice(key string) (is []int) { func (c *Context) GetIntSlice(key any) (is []int) {
return getTyped[[]int](c, key) return getTyped[[]int](c, key)
} }
// GetInt8Slice returns the value associated with the key as a slice of int8 integers. // GetInt8Slice returns the value associated with the key as a slice of int8 integers.
func (c *Context) GetInt8Slice(key string) (i8s []int8) { func (c *Context) GetInt8Slice(key any) (i8s []int8) {
return getTyped[[]int8](c, key) return getTyped[[]int8](c, key)
} }
// GetInt16Slice returns the value associated with the key as a slice of int16 integers. // GetInt16Slice returns the value associated with the key as a slice of int16 integers.
func (c *Context) GetInt16Slice(key string) (i16s []int16) { func (c *Context) GetInt16Slice(key any) (i16s []int16) {
return getTyped[[]int16](c, key) return getTyped[[]int16](c, key)
} }
// GetInt32Slice returns the value associated with the key as a slice of int32 integers. // GetInt32Slice returns the value associated with the key as a slice of int32 integers.
func (c *Context) GetInt32Slice(key string) (i32s []int32) { func (c *Context) GetInt32Slice(key any) (i32s []int32) {
return getTyped[[]int32](c, key) return getTyped[[]int32](c, key)
} }
// GetInt64Slice returns the value associated with the key as a slice of int64 integers. // GetInt64Slice returns the value associated with the key as a slice of int64 integers.
func (c *Context) GetInt64Slice(key string) (i64s []int64) { func (c *Context) GetInt64Slice(key any) (i64s []int64) {
return getTyped[[]int64](c, key) return getTyped[[]int64](c, key)
} }
// GetUintSlice returns the value associated with the key as a slice of unsigned integers. // GetUintSlice returns the value associated with the key as a slice of unsigned integers.
func (c *Context) GetUintSlice(key string) (uis []uint) { func (c *Context) GetUintSlice(key any) (uis []uint) {
return getTyped[[]uint](c, key) return getTyped[[]uint](c, key)
} }
// GetUint8Slice returns the value associated with the key as a slice of uint8 integers. // GetUint8Slice returns the value associated with the key as a slice of uint8 integers.
func (c *Context) GetUint8Slice(key string) (ui8s []uint8) { func (c *Context) GetUint8Slice(key any) (ui8s []uint8) {
return getTyped[[]uint8](c, key) return getTyped[[]uint8](c, key)
} }
// GetUint16Slice returns the value associated with the key as a slice of uint16 integers. // GetUint16Slice returns the value associated with the key as a slice of uint16 integers.
func (c *Context) GetUint16Slice(key string) (ui16s []uint16) { func (c *Context) GetUint16Slice(key any) (ui16s []uint16) {
return getTyped[[]uint16](c, key) return getTyped[[]uint16](c, key)
} }
// GetUint32Slice returns the value associated with the key as a slice of uint32 integers. // GetUint32Slice returns the value associated with the key as a slice of uint32 integers.
func (c *Context) GetUint32Slice(key string) (ui32s []uint32) { func (c *Context) GetUint32Slice(key any) (ui32s []uint32) {
return getTyped[[]uint32](c, key) return getTyped[[]uint32](c, key)
} }
// GetUint64Slice returns the value associated with the key as a slice of uint64 integers. // GetUint64Slice returns the value associated with the key as a slice of uint64 integers.
func (c *Context) GetUint64Slice(key string) (ui64s []uint64) { func (c *Context) GetUint64Slice(key any) (ui64s []uint64) {
return getTyped[[]uint64](c, key) return getTyped[[]uint64](c, key)
} }
// GetFloat32Slice returns the value associated with the key as a slice of float32 numbers. // GetFloat32Slice returns the value associated with the key as a slice of float32 numbers.
func (c *Context) GetFloat32Slice(key string) (f32s []float32) { func (c *Context) GetFloat32Slice(key any) (f32s []float32) {
return getTyped[[]float32](c, key) return getTyped[[]float32](c, key)
} }
// GetFloat64Slice returns the value associated with the key as a slice of float64 numbers. // GetFloat64Slice returns the value associated with the key as a slice of float64 numbers.
func (c *Context) GetFloat64Slice(key string) (f64s []float64) { func (c *Context) GetFloat64Slice(key any) (f64s []float64) {
return getTyped[[]float64](c, key) return getTyped[[]float64](c, key)
} }
// GetStringSlice returns the value associated with the key as a slice of strings. // GetStringSlice returns the value associated with the key as a slice of strings.
func (c *Context) GetStringSlice(key string) (ss []string) { func (c *Context) GetStringSlice(key any) (ss []string) {
return getTyped[[]string](c, key) return getTyped[[]string](c, key)
} }
// GetStringMap returns the value associated with the key as a map of interfaces. // GetStringMap returns the value associated with the key as a map of interfaces.
func (c *Context) GetStringMap(key string) (sm map[string]any) { func (c *Context) GetStringMap(key any) (sm map[string]any) {
return getTyped[map[string]any](c, key) return getTyped[map[string]any](c, key)
} }
// GetStringMapString returns the value associated with the key as a map of strings. // GetStringMapString returns the value associated with the key as a map of strings.
func (c *Context) GetStringMapString(key string) (sms map[string]string) { func (c *Context) GetStringMapString(key any) (sms map[string]string) {
return getTyped[map[string]string](c, key) return getTyped[map[string]string](c, key)
} }
// GetStringMapStringSlice returns the value associated with the key as a map to a slice of strings. // GetStringMapStringSlice returns the value associated with the key as a map to a slice of strings.
func (c *Context) GetStringMapStringSlice(key string) (smss map[string][]string) { func (c *Context) GetStringMapStringSlice(key any) (smss map[string][]string) {
return getTyped[map[string][]string](c, key) return getTyped[map[string][]string](c, key)
} }

View File

@ -257,7 +257,46 @@ func TestContextSetGet(t *testing.T) {
assert.False(t, err) assert.False(t, err)
assert.Equal(t, "bar", c.MustGet("foo")) assert.Equal(t, "bar", c.MustGet("foo"))
assert.Panics(t, func() { c.MustGet("no_exist") }) assert.Panicsf(t, func() {
c.MustGet("no_exist")
}, "key no_exist does not exist")
}
func TestContextSetGetAnyKey(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
type key struct{}
tests := []struct {
key any
}{
{1},
{int32(1)},
{int64(1)},
{uint(1)},
{float32(1)},
{key{}},
{&key{}},
}
for _, tt := range tests {
t.Run(reflect.TypeOf(tt.key).String(), func(t *testing.T) {
c.Set(tt.key, 1)
value, ok := c.Get(tt.key)
assert.True(t, ok)
assert.Equal(t, 1, value)
})
}
}
func TestContextSetGetPanicsWhenKeyNotComparable(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
assert.Panics(t, func() {
c.Set([]int{1}, 1)
c.Set(func() {}, 1)
c.Set(make(chan int), 1)
})
} }
func TestContextSetGetValues(t *testing.T) { func TestContextSetGetValues(t *testing.T) {

View File

@ -82,7 +82,7 @@ type LogFormatterParams struct {
// BodySize is the size of the Response Body // BodySize is the size of the Response Body
BodySize int BodySize int
// Keys are the keys set on the request's context. // Keys are the keys set on the request's context.
Keys map[string]any Keys map[any]any
} }
// StatusCodeColor is the ANSI color for appropriately logging http status code to a terminal. // StatusCodeColor is the ANSI color for appropriately logging http status code to a terminal.

View File

@ -181,7 +181,7 @@ func TestLoggerWithFormatter(t *testing.T) {
func TestLoggerWithConfigFormatting(t *testing.T) { func TestLoggerWithConfigFormatting(t *testing.T) {
var gotParam LogFormatterParams var gotParam LogFormatterParams
var gotKeys map[string]any var gotKeys map[any]any
buffer := new(strings.Builder) buffer := new(strings.Builder)
router := New() router := New()