From b9ca6156501206dd767e5d82266a41733330cd47 Mon Sep 17 00:00:00 2001 From: v-caomk Date: Mon, 10 Jul 2023 00:03:26 +0800 Subject: [PATCH] #3660 return error when can not convert a c.Get(key) value by using GetXxxxx() --- context.go | 118 +++++++++++++++++++++++++++++++++++++----------- context_test.go | 51 +++++++++++++-------- 2 files changed, 124 insertions(+), 45 deletions(-) diff --git a/context.go b/context.go index 420ff167..613804d1 100644 --- a/context.go +++ b/context.go @@ -15,6 +15,7 @@ import ( "net/url" "os" "path/filepath" + "reflect" "strings" "sync" "time" @@ -274,105 +275,170 @@ func (c *Context) MustGet(key string) any { } // GetString returns the value associated with the key as a string. -func (c *Context) GetString(key string) (s string) { +func (c *Context) GetString(key string) (s string, err error) { if val, ok := c.Get(key); ok && val != nil { - s, _ = val.(string) + s, ok := val.(string) + if ok { + return s, nil + } else { + return "", errors.New("can not convert type: " + reflect.TypeOf(val).String() + "to string") + } } return } // GetBool returns the value associated with the key as a boolean. -func (c *Context) GetBool(key string) (b bool) { +func (c *Context) GetBool(key string) (b bool, err error) { if val, ok := c.Get(key); ok && val != nil { - b, _ = val.(bool) + b, ok := val.(bool) + if ok { + return b, nil + } else { + return false, errors.New("can not convert type: " + reflect.TypeOf(val).String() + "to bool") + } } return } // GetInt returns the value associated with the key as an integer. -func (c *Context) GetInt(key string) (i int) { +func (c *Context) GetInt(key string) (i int, err error) { if val, ok := c.Get(key); ok && val != nil { - i, _ = val.(int) + i, ok := val.(int) + if ok { + return i, nil + } else { + return 0, errors.New("can not convert type: " + reflect.TypeOf(val).String() + "to int") + } } return } // GetInt64 returns the value associated with the key as an integer. -func (c *Context) GetInt64(key string) (i64 int64) { +func (c *Context) GetInt64(key string) (i64 int64, err error) { if val, ok := c.Get(key); ok && val != nil { - i64, _ = val.(int64) + i64, ok := val.(int64) + if ok { + return i64, nil + } else { + return 0, errors.New("can not convert type: " + reflect.TypeOf(val).String() + "to int64") + } } return } // GetUint returns the value associated with the key as an unsigned integer. -func (c *Context) GetUint(key string) (ui uint) { +func (c *Context) GetUint(key string) (ui uint, err error) { if val, ok := c.Get(key); ok && val != nil { - ui, _ = val.(uint) + ui, ok := val.(uint) + if ok { + return ui, nil + } else { + return 0, errors.New("can not convert type: " + reflect.TypeOf(val).String() + "to uint") + } } return } // GetUint64 returns the value associated with the key as an unsigned integer. -func (c *Context) GetUint64(key string) (ui64 uint64) { +func (c *Context) GetUint64(key string) (ui64 uint64, err error) { if val, ok := c.Get(key); ok && val != nil { - ui64, _ = val.(uint64) + ui64, ok := val.(uint64) + if ok { + return ui64, nil + } else { + return 0, errors.New("can not convert type: " + reflect.TypeOf(val).String() + "to uint64") + } } return } // GetFloat64 returns the value associated with the key as a float64. -func (c *Context) GetFloat64(key string) (f64 float64) { +func (c *Context) GetFloat64(key string) (f64 float64, err error) { if val, ok := c.Get(key); ok && val != nil { - f64, _ = val.(float64) + f64, ok := val.(float64) + if ok { + return f64, nil + } else { + return 0, errors.New("can not convert type: " + reflect.TypeOf(val).String() + "to float64") + } } return } // GetTime returns the value associated with the key as time. -func (c *Context) GetTime(key string) (t time.Time) { +func (c *Context) GetTime(key string) (t time.Time, err error) { if val, ok := c.Get(key); ok && val != nil { - t, _ = val.(time.Time) + t, ok := val.(time.Time) + if ok { + return t, nil + } else { + return time.Time{}, errors.New("can not convert type: " + reflect.TypeOf(val).String() + "to time.Time") + } } return } // GetDuration returns the value associated with the key as a duration. -func (c *Context) GetDuration(key string) (d time.Duration) { +func (c *Context) GetDuration(key string) (d time.Duration, err error) { if val, ok := c.Get(key); ok && val != nil { - d, _ = val.(time.Duration) + d, ok := val.(time.Duration) + if ok { + return d, nil + } else { + return 0, errors.New("can not convert type: " + reflect.TypeOf(val).String() + "to time.Duration") + } } return } // GetStringSlice returns the value associated with the key as a slice of strings. -func (c *Context) GetStringSlice(key string) (ss []string) { +func (c *Context) GetStringSlice(key string) (ss []string, err error) { if val, ok := c.Get(key); ok && val != nil { - ss, _ = val.([]string) + ss, ok := val.([]string) + if ok { + return ss, nil + } else { + return []string{}, errors.New("can not convert type: " + reflect.TypeOf(val).String() + "to []string") + } } return } // GetStringMap returns the value associated with the key as a map of interfaces. -func (c *Context) GetStringMap(key string) (sm map[string]any) { +func (c *Context) GetStringMap(key string) (sm map[string]any, err error) { if val, ok := c.Get(key); ok && val != nil { - sm, _ = val.(map[string]any) + sm, ok := val.(map[string]any) + if ok { + return sm, nil + } else { + return map[string]any{}, errors.New("can not convert type: " + reflect.TypeOf(val).String() + "to map[string]any") + } } return } // GetStringMapString returns the value associated with the key as a map of strings. -func (c *Context) GetStringMapString(key string) (sms map[string]string) { +func (c *Context) GetStringMapString(key string) (sms map[string]string, err error) { if val, ok := c.Get(key); ok && val != nil { - sms, _ = val.(map[string]string) + sms, ok := val.(map[string]string) + if ok { + return sms, nil + } else { + return map[string]string{}, errors.New("can not convert type: " + reflect.TypeOf(val).String() + "to map[string]string") + } } return } // GetStringMapStringSlice returns the value associated with the key as a map to a slice of strings. -func (c *Context) GetStringMapStringSlice(key string) (smss map[string][]string) { +func (c *Context) GetStringMapStringSlice(key string) (smss map[string][]string, err error) { if val, ok := c.Get(key); ok && val != nil { - smss, _ = val.(map[string][]string) + smss, ok := val.(map[string][]string) + if ok { + return smss, nil + } else { + return map[string][]string{}, errors.New("can not convert type: " + reflect.TypeOf(val).String() + "to map[string][]string") + } } return } diff --git a/context_test.go b/context_test.go index 70d47583..79038001 100644 --- a/context_test.go +++ b/context_test.go @@ -229,62 +229,72 @@ func TestContextSetGetValues(t *testing.T) { func TestContextGetString(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) c.Set("string", "this is a string") - assert.Equal(t, "this is a string", c.GetString("string")) + s, _ := c.GetString("string") + assert.Equal(t, "this is a string", s) } func TestContextSetGetBool(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) c.Set("bool", true) - assert.True(t, c.GetBool("bool")) + b, _ := c.GetBool("bool") + assert.True(t, b) } func TestContextGetInt(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) c.Set("int", 1) - assert.Equal(t, 1, c.GetInt("int")) + i, _ := c.GetInt("int") + assert.Equal(t, 1, i) } func TestContextGetInt64(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) c.Set("int64", int64(42424242424242)) - assert.Equal(t, int64(42424242424242), c.GetInt64("int64")) + i64, _ := c.GetInt64("int64") + assert.Equal(t, int64(42424242424242), i64) } func TestContextGetUint(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) c.Set("uint", uint(1)) - assert.Equal(t, uint(1), c.GetUint("uint")) + ui, _ := c.GetUint("uint") + assert.Equal(t, uint(1), ui) } func TestContextGetUint64(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) c.Set("uint64", uint64(18446744073709551615)) - assert.Equal(t, uint64(18446744073709551615), c.GetUint64("uint64")) + ui64, _ := c.GetUint64("uint64") + assert.Equal(t, uint64(18446744073709551615), ui64) } func TestContextGetFloat64(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) c.Set("float64", 4.2) - assert.Equal(t, 4.2, c.GetFloat64("float64")) + f64, _ := c.GetFloat64("float64") + assert.Equal(t, 4.2, f64) } func TestContextGetTime(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) t1, _ := time.Parse("1/2/2006 15:04:05", "01/01/2017 12:00:00") c.Set("time", t1) - assert.Equal(t, t1, c.GetTime("time")) + getTime, _ := c.GetTime("time") + assert.Equal(t, t1, getTime) } func TestContextGetDuration(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) c.Set("duration", time.Second) - assert.Equal(t, time.Second, c.GetDuration("duration")) + d, _ := c.GetDuration("duration") + assert.Equal(t, time.Second, d) } func TestContextGetStringSlice(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) c.Set("slice", []string{"foo"}) - assert.Equal(t, []string{"foo"}, c.GetStringSlice("slice")) + ss, _ := c.GetStringSlice("slice") + assert.Equal(t, []string{"foo"}, ss) } func TestContextGetStringMap(t *testing.T) { @@ -292,9 +302,10 @@ func TestContextGetStringMap(t *testing.T) { m := make(map[string]any) m["foo"] = 1 c.Set("map", m) - - assert.Equal(t, m, c.GetStringMap("map")) - assert.Equal(t, 1, c.GetStringMap("map")["foo"]) + sm, _ := c.GetStringMap("map") + assert.Equal(t, m, sm) + stringMap, _ := c.GetStringMap("map") + assert.Equal(t, 1, stringMap["foo"]) } func TestContextGetStringMapString(t *testing.T) { @@ -302,9 +313,10 @@ func TestContextGetStringMapString(t *testing.T) { m := make(map[string]string) m["foo"] = "bar" c.Set("map", m) - - assert.Equal(t, m, c.GetStringMapString("map")) - assert.Equal(t, "bar", c.GetStringMapString("map")["foo"]) + sms, _ := c.GetStringMapString("map") + assert.Equal(t, m, sms) + mapString, _ := c.GetStringMapString("map") + assert.Equal(t, "bar", mapString["foo"]) } func TestContextGetStringMapStringSlice(t *testing.T) { @@ -312,9 +324,10 @@ func TestContextGetStringMapStringSlice(t *testing.T) { m := make(map[string][]string) m["foo"] = []string{"foo"} c.Set("map", m) - - assert.Equal(t, m, c.GetStringMapStringSlice("map")) - assert.Equal(t, []string{"foo"}, c.GetStringMapStringSlice("map")["foo"]) + smss, _ := c.GetStringMapStringSlice("map") + assert.Equal(t, m, smss) + slice, _ := c.GetStringMapStringSlice("map") + assert.Equal(t, []string{"foo"}, slice["foo"]) } func TestContextCopy(t *testing.T) {