From 9e03cea26926b662bd3bc1d851caef318d1b6ad3 Mon Sep 17 00:00:00 2001 From: guonaihong Date: Wed, 12 Jun 2019 19:13:11 +0800 Subject: [PATCH 1/2] achieve #1898 ```go package main import ( "fmt" "github.com/gin-gonic/gin" ) func main() { r := gin.Default() r.GET("/", func(c *gin.Context) { var i int var err error err = c.DefaultQueryVar("int", &i, -1) fmt.Printf("%v, %v\n", i, err) var ss []string err = c.DefaultQueryVar("slice", &ss, []string{"5", "5", "5"}) fmt.Printf("%v, %v\n", ss, err) var b bool err = c.DefaultQueryVar("bool", &b, false) fmt.Printf("%v, %v\n", b, err) var f float64 err = c.DefaultQueryVar("f", &f, 3.14) fmt.Printf("%v, %v\n", f, err) }) r.Run() return } ``` --- binding/query.go | 134 +++++++++++++++++++++++++++++++++++++++++- binding/query_test.go | 72 +++++++++++++++++++++++ context.go | 59 +++++++++++++++++++ 3 files changed, 264 insertions(+), 1 deletion(-) create mode 100644 binding/query_test.go diff --git a/binding/query.go b/binding/query.go index 219743f2..38f38f76 100644 --- a/binding/query.go +++ b/binding/query.go @@ -4,7 +4,13 @@ package binding -import "net/http" +import ( + "errors" + "fmt" + "net/http" + "reflect" + "time" +) type queryBinding struct{} @@ -19,3 +25,129 @@ func (queryBinding) Bind(req *http.Request, obj interface{}) error { } return validate(obj) } + +var intBitSize = map[reflect.Kind]int{ + reflect.Int: 0, + reflect.Int8: 8, + reflect.Int16: 16, + reflect.Int32: 32, + reflect.Int64: 64, +} + +var uintBitSize = map[reflect.Kind]int{ + reflect.Uint: 0, + reflect.Uint8: 8, + reflect.Uint16: 16, + reflect.Uint32: 32, + reflect.Uint64: 64, +} + +var floatBitSize = map[reflect.Kind]int{ + reflect.Float32: 32, + reflect.Float64: 64, +} + +var durationType = reflect.TypeOf(time.Duration(0)) +var timeType = reflect.TypeOf(time.Time{}) + +func setTime(value string, val reflect.Value) (err error) { + var t time.Time + if t, err = time.ParseInLocation(time.RFC3339, value, time.Local); err != nil { + return err + } + val.Set(reflect.ValueOf(t)) + return +} + +func parseBaseTypeVar(value string, ptr reflect.Value) (err error) { + val := ptr.Elem() + switch val.Kind() { + // bool + case reflect.Bool: + return setBoolField(value, val) + // string + case reflect.String: + val.SetString(value) + return + case reflect.Int64: + if val.Type() == durationType { + return setTimeDuration(value, val, reflect.StructField{}) + } + + case reflect.Struct: + if val.Type() == timeType { + return setTime(value, val) + } + } + + // int, int8, int16, int32, int64 + if bs, ok := intBitSize[val.Kind()]; ok { + return setIntField(value, bs, val) + } + + // uint, uint8, uint16, uint32, uint64 + if bs, ok := uintBitSize[val.Kind()]; ok { + return setUintField(value, bs, val) + } + + // float32 float64 + if bs, ok := floatBitSize[val.Kind()]; ok { + return setFloatField(value, bs, val) + } + + return nil +} + +func setSlice2(values []string, ptr reflect.Value) error { + slice := reflect.MakeSlice(ptr.Elem().Type(), len(values), len(values)) + ptr.Elem().Set(slice) + if err := setArray2(values, ptr); err != nil { + return err + } + + ptr.Elem().Set(slice) + return nil + +} + +func setArray2(values []string, ptr reflect.Value) error { + if ptr.Elem().Len() != len(values) { + return fmt.Errorf("Unequal length:%d:%d", ptr.Elem().Len(), len(values)) + } + + for i, v := range values { + if err := parseBaseTypeVar(v, ptr.Elem().Index(i).Addr()); err != nil { + return err + } + } + return nil +} + +// slice, array +// base type +func parseTypeVar(ptr reflect.Value, values []string) error { + switch ptr.Elem().Kind() { + case reflect.Slice: + return setSlice2(values, ptr) + case reflect.Array: + return setArray2(values, ptr) + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, + reflect.Ptr, reflect.Struct, reflect.UnsafePointer: + + if ptr.Elem().Type() == timeType { + return parseBaseTypeVar(values[0], ptr) + } + return errors.New("Unsupported type") + default: + return parseBaseTypeVar(values[0], ptr) + } +} + +func SetValue(ptr, defaultVal reflect.Value, values []string, ok bool) error { + if ok { + return parseTypeVar(ptr, values) + } + + ptr.Elem().Set(defaultVal) + return nil +} diff --git a/binding/query_test.go b/binding/query_test.go new file mode 100644 index 00000000..0e51ec3c --- /dev/null +++ b/binding/query_test.go @@ -0,0 +1,72 @@ +package binding + +import ( + "reflect" + "testing" + "time" +) + +type typeTest struct { + needParse []string + ptr interface{} + want interface{} +} + +func TestParseTypeVar(t *testing.T) { + var ( + b bool + i int + i8 int8 + i16 int16 + i32 int32 + i64 int64 + u uint + u8 uint8 + u16 uint16 + u32 uint32 + u64 uint64 + s string + f32 float32 + f64 float64 + duration time.Duration + stringSlice []string + intSlice []int + float32Slice []float32 + stringArray [3]string + tm time.Time + ) + + tv := []typeTest{ + {needParse: []string{"1"}, ptr: &i, want: 1}, + {needParse: []string{"2"}, ptr: &i8, want: int8(2)}, + {needParse: []string{"3"}, ptr: &i16, want: int16(3)}, + {needParse: []string{"4"}, ptr: &i32, want: int32(4)}, + {needParse: []string{"5"}, ptr: &i64, want: int64(5)}, + {needParse: []string{"6"}, ptr: &u, want: uint(6)}, + {needParse: []string{"7"}, ptr: &u8, want: uint8(7)}, + {needParse: []string{"8"}, ptr: &u16, want: uint16(8)}, + {needParse: []string{"9"}, ptr: &u32, want: uint32(9)}, + {needParse: []string{"10"}, ptr: &u64, want: uint64(10)}, + {needParse: []string{"test"}, ptr: &s, want: "test"}, + {needParse: []string{"1.1"}, ptr: &f32, want: float32(1.1)}, + {needParse: []string{"2.2"}, ptr: &f64, want: float64(2.2)}, + {needParse: []string{"1", "2", "3"}, ptr: &stringSlice, want: []string{"1", "2", "3"}}, + {needParse: []string{"1", "2", "3"}, ptr: &intSlice, want: []int{1, 2, 3}}, + {needParse: []string{"4.1", "5.1", "6.1"}, ptr: &float32Slice, want: []float32{4.1, 5.1, 6.1}}, + {needParse: []string{"a1", "a2", "a3"}, ptr: &stringArray, want: [3]string{"a1", "a2", "a3"}}, + {needParse: []string{"true"}, ptr: &b, want: true}, + {needParse: []string{"1s"}, ptr: &duration, want: time.Second}, + {needParse: []string{"2006-01-02T15:04:05Z"}, ptr: &tm, want: time.Date(2006, 1, 2, 15, 4, 5, 0, time.UTC)}, + } + + for k := range tv { + if err := parseTypeVar(reflect.ValueOf(tv[k].ptr), tv[k].needParse); err != nil { + t.Errorf("parseBaseTypeVar %T fail:%s\n", tv[k].want, err) + } + + if !reflect.DeepEqual(reflect.ValueOf(tv[k].ptr).Elem().Interface(), tv[k].want) { + t.Errorf("parseBaseTypeVar %T fail got:%v, want:%v\n", tv[k].ptr, tv[k].ptr, tv[k].want) + } + } + +} diff --git a/context.go b/context.go index ffb9a2de..e3ddfef1 100644 --- a/context.go +++ b/context.go @@ -15,6 +15,7 @@ import ( "net/http" "net/url" "os" + "reflect" "strings" "time" @@ -354,6 +355,64 @@ func (c *Context) Query(key string) string { return value } +// GET /?bool=true&int=3&slice=1&slice=2&slice=3 +// var i int +// err = c.QueryVar("int", &i) +// i == 3 + +// var ss []string +// err = c.QueryVar("slice", &ss) +// ss == []string{"1", "2", "3"} + +// var b bool +// err = c.QueryVar("bool", &b) +// b == true + +// var f float64 +// err = c.QueryVar("f", &f) +// f == 0.0 +func (c *Context) QueryVar(key string, val interface{}) error { + rv := reflect.ValueOf(val) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return errors.New("Invalid parameter") + } + + values, ok := c.GetQueryArray(key) + return binding.SetValue(rv, rv.Elem(), values, ok) +} + +// GET /?bool=true&int=3&slice=1&slice=2&slice=3 +// var i int +// err = c.DefaultQueryVar("int", &i, -1) +// i == 3 + +// var ss []string +// err = c.DefaultQueryVar("slice", &ss, []string{}) +// ss == []string{"1", "2", "3"} + +// var b bool +// err = c.DefaultQueryVar("bool", &b, false) +// b == true + +// var f float64 +// err = c.DefaultQueryVar("f", &f, 3.14) +// f == 3.14 + +func (c *Context) DefaultQueryVar(key string, val interface{}, defaultValue interface{}) error { + rv := reflect.ValueOf(val) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return errors.New("Invalid parameter") + } + + if rv.Elem().Type() != reflect.TypeOf(defaultValue) { + return fmt.Errorf("type fail: defautValue type is %v: value type is %v:", + reflect.TypeOf(defaultValue), rv.Elem().Type()) + } + + values, ok := c.GetQueryArray(key) + return binding.SetValue(rv, reflect.ValueOf(defaultValue), values, ok) +} + // DefaultQuery returns the keyed url query value if it exists, // otherwise it returns the specified defaultValue string. // See: Query() and GetQuery() for further information. From b4e2ca45851c5a656de7cb77b803a10d7e953f91 Mon Sep 17 00:00:00 2001 From: guonaihong Date: Fri, 14 Jun 2019 19:35:26 +0800 Subject: [PATCH 2/2] achieve #1866 ```go package main import ( "fmt" "github.com/gin-gonic/gin" "time" ) func main() { r := gin.Default() r.GET("/:int/:string/:float/:bool/:duration", func(c *gin.Context) { var i int var err error err = c.ParamVar("int", &i) fmt.Printf("%v, %v\n", i, err) var b bool err = c.ParamVar("bool", &b) fmt.Printf("%v, %v\n", b, err) var f float64 err = c.ParamVar("float", &f) fmt.Printf("%v, %v\n", f, err) var s string err = c.ParamVar("string", &s) fmt.Printf("%v, %v\n", s, err) var d time.Duration err = c.ParamVar("duration", &d) fmt.Printf("%v, %v\n", d, err) }) r.Run() return } // client // curl -X GET 127.0.0.1:8080/1/test/3.14/true/1s ``` --- context.go | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/context.go b/context.go index e3ddfef1..f65f9402 100644 --- a/context.go +++ b/context.go @@ -328,6 +328,31 @@ func (c *Context) GetStringMapStringSlice(key string) (smss map[string][]string) return } +// ParamVar get the value of the URL param. +// curl -X GET 127.0.0.1:8080/1/test/3.14/true/1s +// +// router.GET("/:int/:string/:float/:bool/:duration", func(c *gin.Context) { +// var i int +// var b bool +// var f float64 +// var s string +// var d time.Duration + +// err = c.ParamVar("int", &i) // int == 1 +// err = c.ParamVar("bool", &b) // bool == true +// err = c.ParamVar("float", &f) // float == 3.14 +// err = c.ParamVar("string", &s) // string == test +// err = c.ParamVar("duration", &d) // duration == time.Second +// }) +func (c *Context) ParamVar(key string, val interface{}) error { + rv := reflect.ValueOf(val) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return errors.New("Invalid parameter") + } + + return binding.SetValue(rv, rv.Elem(), []string{c.Param(key)}, true) +} + /************************************/ /************ INPUT DATA ************/ /************************************/