From 41d8591eb16bf23732de9ae2b699d6cae54c2ae8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Flc=E3=82=9B?= Date: Mon, 26 May 2025 23:15:14 +0800 Subject: [PATCH] refactor(context): refactor `Keys` type to `map[any]any` (#3963) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor(context): refactor keys to `map[any]any` Signed-off-by: Flc゛ * refactor(context): refactor keys to `map[any]any` Signed-off-by: Flc゛ * style(context): remove empty lines before GetInt16, GetIntSlice, and GetStringMapString methods - Remove unnecessary empty lines in the context.go file - Improve code readability and consistency Signed-off-by: flc1125 * refactor(context): simplify GetStringSlice function - Replace manual type assertion with generic getTyped function - Reduce code duplication and improve type safety Signed-off-by: flc1125 * test(context): improve context.Set and context.Get tests - Split existing test into separate functions for different scenarios - Add test for setting and getting values with any key type - Add test for handling non-comparable keys - Improve assertions to check for key existence and value correctness Signed-off-by: flc1125 * refactor(context): replace fmt.Errorf with fmt.Sprintf in panic message * test(context): remove trailing hyphen from context_test.go * refactor(context): improve error message for missing key in context - Remove unnecessary quotes around the key in the error message - Simplify the error message format for better readability * test(context): improve panic test message for non-existent key --------- Signed-off-by: Flc゛ Signed-off-by: flc1125 --- context.go | 81 +++++++++++++++++++++++++------------------------ context_test.go | 41 ++++++++++++++++++++++++- logger.go | 2 +- logger_test.go | 2 +- 4 files changed, 83 insertions(+), 43 deletions(-) diff --git a/context.go b/context.go index 3ebb3ee4..bf12830c 100644 --- a/context.go +++ b/context.go @@ -6,6 +6,7 @@ package gin import ( "errors" + "fmt" "io" "io/fs" "log" @@ -72,7 +73,7 @@ type Context struct { mu sync.RWMutex // Keys is a key/value pair exclusively for the context of each request. - Keys map[string]any + Keys map[any]any // Errors is a list of errors attached to all the handlers/middlewares who used this context. Errors errorMsgs @@ -129,7 +130,7 @@ func (c *Context) Copy() *Context { cp.fullPath = c.fullPath cKeys := c.Keys - cp.Keys = make(map[string]any, len(cKeys)) + cp.Keys = make(map[any]any, len(cKeys)) c.mu.RLock() for k, v := range cKeys { cp.Keys[k] = v @@ -264,11 +265,11 @@ func (c *Context) Error(err error) *Error { // Set is used to store a new key/value pair exclusively for this context. // It also lazy initializes c.Keys if it was not used previously. -func (c *Context) Set(key string, value any) { +func (c *Context) Set(key any, value any) { c.mu.Lock() defer c.mu.Unlock() if c.Keys == nil { - c.Keys = make(map[string]any) + c.Keys = make(map[any]any) } c.Keys[key] = value @@ -276,7 +277,7 @@ func (c *Context) Set(key string, value any) { // Get returns the value for the given key, ie: (value, true). // If the value does not exist it returns (nil, false) -func (c *Context) Get(key string) (value any, exists bool) { +func (c *Context) Get(key any) (value any, exists bool) { c.mu.RLock() defer c.mu.RUnlock() value, exists = c.Keys[key] @@ -284,14 +285,14 @@ func (c *Context) Get(key string) (value any, exists bool) { } // MustGet returns the value for the given key if it exists, otherwise it panics. -func (c *Context) MustGet(key string) any { +func (c *Context) MustGet(key any) any { if value, exists := c.Get(key); exists { return value } - panic("Key \"" + key + "\" does not exist") + panic(fmt.Sprintf("key %v does not exist", key)) } -func getTyped[T any](c *Context, key string) (res T) { +func getTyped[T any](c *Context, key any) (res T) { if val, ok := c.Get(key); ok && val != nil { res, _ = val.(T) } @@ -299,162 +300,162 @@ func getTyped[T any](c *Context, key string) (res T) { } // GetString returns the value associated with the key as a string. -func (c *Context) GetString(key string) (s string) { +func (c *Context) GetString(key any) (s string) { return getTyped[string](c, key) } // GetBool returns the value associated with the key as a boolean. -func (c *Context) GetBool(key string) (b bool) { +func (c *Context) GetBool(key any) (b bool) { return getTyped[bool](c, key) } // GetInt returns the value associated with the key as an integer. -func (c *Context) GetInt(key string) (i int) { +func (c *Context) GetInt(key any) (i int) { return getTyped[int](c, key) } // GetInt8 returns the value associated with the key as an integer 8. -func (c *Context) GetInt8(key string) (i8 int8) { +func (c *Context) GetInt8(key any) (i8 int8) { return getTyped[int8](c, key) } // GetInt16 returns the value associated with the key as an integer 16. -func (c *Context) GetInt16(key string) (i16 int16) { +func (c *Context) GetInt16(key any) (i16 int16) { return getTyped[int16](c, key) } // GetInt32 returns the value associated with the key as an integer 32. -func (c *Context) GetInt32(key string) (i32 int32) { +func (c *Context) GetInt32(key any) (i32 int32) { return getTyped[int32](c, key) } // GetInt64 returns the value associated with the key as an integer 64. -func (c *Context) GetInt64(key string) (i64 int64) { +func (c *Context) GetInt64(key any) (i64 int64) { return getTyped[int64](c, key) } // GetUint returns the value associated with the key as an unsigned integer. -func (c *Context) GetUint(key string) (ui uint) { +func (c *Context) GetUint(key any) (ui uint) { return getTyped[uint](c, key) } // GetUint8 returns the value associated with the key as an unsigned integer 8. -func (c *Context) GetUint8(key string) (ui8 uint8) { +func (c *Context) GetUint8(key any) (ui8 uint8) { return getTyped[uint8](c, key) } // GetUint16 returns the value associated with the key as an unsigned integer 16. -func (c *Context) GetUint16(key string) (ui16 uint16) { +func (c *Context) GetUint16(key any) (ui16 uint16) { return getTyped[uint16](c, key) } // GetUint32 returns the value associated with the key as an unsigned integer 32. -func (c *Context) GetUint32(key string) (ui32 uint32) { +func (c *Context) GetUint32(key any) (ui32 uint32) { return getTyped[uint32](c, key) } // GetUint64 returns the value associated with the key as an unsigned integer 64. -func (c *Context) GetUint64(key string) (ui64 uint64) { +func (c *Context) GetUint64(key any) (ui64 uint64) { return getTyped[uint64](c, key) } // GetFloat32 returns the value associated with the key as a float32. -func (c *Context) GetFloat32(key string) (f32 float32) { +func (c *Context) GetFloat32(key any) (f32 float32) { return getTyped[float32](c, key) } // GetFloat64 returns the value associated with the key as a float64. -func (c *Context) GetFloat64(key string) (f64 float64) { +func (c *Context) GetFloat64(key any) (f64 float64) { return getTyped[float64](c, key) } // GetTime returns the value associated with the key as time. -func (c *Context) GetTime(key string) (t time.Time) { +func (c *Context) GetTime(key any) (t time.Time) { return getTyped[time.Time](c, key) } // GetDuration returns the value associated with the key as a duration. -func (c *Context) GetDuration(key string) (d time.Duration) { +func (c *Context) GetDuration(key any) (d time.Duration) { return getTyped[time.Duration](c, key) } // GetIntSlice returns the value associated with the key as a slice of integers. -func (c *Context) GetIntSlice(key string) (is []int) { +func (c *Context) GetIntSlice(key any) (is []int) { return getTyped[[]int](c, key) } // GetInt8Slice returns the value associated with the key as a slice of int8 integers. -func (c *Context) GetInt8Slice(key string) (i8s []int8) { +func (c *Context) GetInt8Slice(key any) (i8s []int8) { return getTyped[[]int8](c, key) } // GetInt16Slice returns the value associated with the key as a slice of int16 integers. -func (c *Context) GetInt16Slice(key string) (i16s []int16) { +func (c *Context) GetInt16Slice(key any) (i16s []int16) { return getTyped[[]int16](c, key) } // GetInt32Slice returns the value associated with the key as a slice of int32 integers. -func (c *Context) GetInt32Slice(key string) (i32s []int32) { +func (c *Context) GetInt32Slice(key any) (i32s []int32) { return getTyped[[]int32](c, key) } // GetInt64Slice returns the value associated with the key as a slice of int64 integers. -func (c *Context) GetInt64Slice(key string) (i64s []int64) { +func (c *Context) GetInt64Slice(key any) (i64s []int64) { return getTyped[[]int64](c, key) } // GetUintSlice returns the value associated with the key as a slice of unsigned integers. -func (c *Context) GetUintSlice(key string) (uis []uint) { +func (c *Context) GetUintSlice(key any) (uis []uint) { return getTyped[[]uint](c, key) } // GetUint8Slice returns the value associated with the key as a slice of uint8 integers. -func (c *Context) GetUint8Slice(key string) (ui8s []uint8) { +func (c *Context) GetUint8Slice(key any) (ui8s []uint8) { return getTyped[[]uint8](c, key) } // GetUint16Slice returns the value associated with the key as a slice of uint16 integers. -func (c *Context) GetUint16Slice(key string) (ui16s []uint16) { +func (c *Context) GetUint16Slice(key any) (ui16s []uint16) { return getTyped[[]uint16](c, key) } // GetUint32Slice returns the value associated with the key as a slice of uint32 integers. -func (c *Context) GetUint32Slice(key string) (ui32s []uint32) { +func (c *Context) GetUint32Slice(key any) (ui32s []uint32) { return getTyped[[]uint32](c, key) } // GetUint64Slice returns the value associated with the key as a slice of uint64 integers. -func (c *Context) GetUint64Slice(key string) (ui64s []uint64) { +func (c *Context) GetUint64Slice(key any) (ui64s []uint64) { return getTyped[[]uint64](c, key) } // GetFloat32Slice returns the value associated with the key as a slice of float32 numbers. -func (c *Context) GetFloat32Slice(key string) (f32s []float32) { +func (c *Context) GetFloat32Slice(key any) (f32s []float32) { return getTyped[[]float32](c, key) } // GetFloat64Slice returns the value associated with the key as a slice of float64 numbers. -func (c *Context) GetFloat64Slice(key string) (f64s []float64) { +func (c *Context) GetFloat64Slice(key any) (f64s []float64) { return getTyped[[]float64](c, key) } // GetStringSlice returns the value associated with the key as a slice of strings. -func (c *Context) GetStringSlice(key string) (ss []string) { +func (c *Context) GetStringSlice(key any) (ss []string) { return getTyped[[]string](c, key) } // GetStringMap returns the value associated with the key as a map of interfaces. -func (c *Context) GetStringMap(key string) (sm map[string]any) { +func (c *Context) GetStringMap(key any) (sm map[string]any) { return getTyped[map[string]any](c, key) } // GetStringMapString returns the value associated with the key as a map of strings. -func (c *Context) GetStringMapString(key string) (sms map[string]string) { +func (c *Context) GetStringMapString(key any) (sms map[string]string) { return getTyped[map[string]string](c, key) } // GetStringMapStringSlice returns the value associated with the key as a map to a slice of strings. -func (c *Context) GetStringMapStringSlice(key string) (smss map[string][]string) { +func (c *Context) GetStringMapStringSlice(key any) (smss map[string][]string) { return getTyped[map[string][]string](c, key) } diff --git a/context_test.go b/context_test.go index c8559505..ff43cd0a 100644 --- a/context_test.go +++ b/context_test.go @@ -257,7 +257,46 @@ func TestContextSetGet(t *testing.T) { assert.False(t, err) assert.Equal(t, "bar", c.MustGet("foo")) - assert.Panics(t, func() { c.MustGet("no_exist") }) + assert.Panicsf(t, func() { + c.MustGet("no_exist") + }, "key no_exist does not exist") +} + +func TestContextSetGetAnyKey(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + + type key struct{} + + tests := []struct { + key any + }{ + {1}, + {int32(1)}, + {int64(1)}, + {uint(1)}, + {float32(1)}, + {key{}}, + {&key{}}, + } + + for _, tt := range tests { + t.Run(reflect.TypeOf(tt.key).String(), func(t *testing.T) { + c.Set(tt.key, 1) + value, ok := c.Get(tt.key) + assert.True(t, ok) + assert.Equal(t, 1, value) + }) + } +} + +func TestContextSetGetPanicsWhenKeyNotComparable(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + + assert.Panics(t, func() { + c.Set([]int{1}, 1) + c.Set(func() {}, 1) + c.Set(make(chan int), 1) + }) } func TestContextSetGetValues(t *testing.T) { diff --git a/logger.go b/logger.go index db2c6832..f4a250ac 100644 --- a/logger.go +++ b/logger.go @@ -82,7 +82,7 @@ type LogFormatterParams struct { // BodySize is the size of the Response Body BodySize int // Keys are the keys set on the request's context. - Keys map[string]any + Keys map[any]any } // StatusCodeColor is the ANSI color for appropriately logging http status code to a terminal. diff --git a/logger_test.go b/logger_test.go index 30b25290..8a542e97 100644 --- a/logger_test.go +++ b/logger_test.go @@ -181,7 +181,7 @@ func TestLoggerWithFormatter(t *testing.T) { func TestLoggerWithConfigFormatting(t *testing.T) { var gotParam LogFormatterParams - var gotKeys map[string]any + var gotKeys map[any]any buffer := new(strings.Builder) router := New()