From 0e0ea62a94e4fe0b5574ad4277f4ba5505ead85c Mon Sep 17 00:00:00 2001 From: abdulrahman Date: Fri, 29 Dec 2023 02:39:19 +0400 Subject: [PATCH 1/2] feature: implement request auto binder utility --- auto_binder.go | 100 ++++++++++++++++++++++++++++++ auto_binder_test.go | 144 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 244 insertions(+) create mode 100644 auto_binder.go create mode 100644 auto_binder_test.go diff --git a/auto_binder.go b/auto_binder.go new file mode 100644 index 00000000..233546f1 --- /dev/null +++ b/auto_binder.go @@ -0,0 +1,100 @@ +package gin + +import ( + "errors" + "fmt" + "reflect" +) + +type binderType func(obj any) error + +func isFunc(obj any) bool { + return reflect.TypeOf(obj).Kind() == reflect.Func +} + +func isGinContext(rt reflect.Type) bool { + return rt == reflect.TypeOf((*Context)(nil)) +} + +func isPtr(rt reflect.Type) bool { + return rt.Kind() == reflect.Pointer +} + +func isStruct(rt reflect.Type) bool { + return rt.Kind() == reflect.Struct +} + +func constructStruct(prt reflect.Type, binder binderType) (reflect.Value, error) { + var pInstancePtr any + + if isPtr(prt) { + pInstancePtr = reflect.New(prt.Elem()).Interface() + } else { + pInstancePtr = reflect.New(prt).Interface() + } + + if err := binder(pInstancePtr); err != nil { + return reflect.Value{}, err + } + + if prt.Kind() == reflect.Pointer { + return reflect.ValueOf(pInstancePtr), nil + } + + return reflect.ValueOf(pInstancePtr).Elem(), nil +} + +func callHandler(rt reflect.Type, rv reflect.Value, ctx *Context, binder binderType) error { + numberOfParams := rt.NumIn() + + var args []reflect.Value + + for i := 0; i < numberOfParams; i++ { + prt := rt.In(i) + + if isGinContext(prt) { + args = append(args, reflect.ValueOf(ctx)) + continue + } + + if isStruct(prt) || isStruct(prt.Elem()) { + if prv, err := constructStruct(prt, binder); err != nil { + return err + } else { + args = append(args, prv) + } + } + } + + rv.Call(args) + + return nil +} + +// AutoBinder is a handler wrapper that binds the actual handler's request. +// +// Example: func MyGetHandler(ctx *gin.Context, request *MyRequest) {} +// +// engine.GET("/endpoint", gin.AutoBinder(MyGetHandler)) +func AutoBinder(handler any) HandlerFunc { + rt := reflect.TypeOf(handler) + + if rt.Kind() != reflect.Func { + panic(errors.New("invalid handler type")) + } + + if rt.NumIn() == 0 { + panic(fmt.Errorf("handler should have at least one parameter, handler: %v", rt.Name())) + } + + return func(ctx *Context) { + rt := reflect.TypeOf(handler) + rv := reflect.ValueOf(handler) + + if err := callHandler(rt, rv, ctx, func(obj any) error { + return ctx.ShouldBind(obj) + }); err != nil { + ctx.Error(err) + } + } +} diff --git a/auto_binder_test.go b/auto_binder_test.go new file mode 100644 index 00000000..c20e66c3 --- /dev/null +++ b/auto_binder_test.go @@ -0,0 +1,144 @@ +package gin + +import ( + "encoding/json" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +type myRequest struct { + Field1 string `json:"field_1"` +} + +func TestAutoBinder_isFunc(t *testing.T) { + tests := []struct { + name string + input any + want bool + }{ + { + "valid function", + func(string, int) error { return nil }, + true, + }, + { + "valid zero-param function", + func() error { return nil }, + true, + }, + { + "invalid function", + func() string { return "" }(), + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := isFunc(tt.input) + assert.Equal(t, tt.want, actual) + }) + } +} + +func TestAutoBinder_isGinContext(t *testing.T) { + assert.True(t, isGinContext(reflect.TypeOf(&Context{}))) + assert.False(t, isGinContext(reflect.TypeOf(Context{}))) + assert.False(t, isGinContext(reflect.TypeOf([]string{}))) +} + +func TestAutoBinder_constructStruct_pointer(t *testing.T) { + type myType struct { + Field int `json:"field"` + } + + rv, err := constructStruct(reflect.TypeOf(&myType{}), func(obj any) error { + assert.True(t, isPtr(reflect.TypeOf(obj))) + + return json.Unmarshal( + []byte(`{"field": 10}`), + obj, + ) + }) + + assert.NoError(t, err) + + instance, ok := rv.Interface().(*myType) + + assert.True(t, ok) + + assert.Equal(t, 10, instance.Field) +} + +func TestAutoBinder_constructStruct_nonPointer(t *testing.T) { + type myType struct { + Field int `json:"field"` + } + + rv, err := constructStruct(reflect.TypeOf(myType{}), func(obj any) error { + assert.True(t, isPtr(reflect.TypeOf(obj))) + + return json.Unmarshal( + []byte(`{"field": 10}`), + obj, + ) + }) + + assert.NoError(t, err) + + instance, ok := rv.Interface().(myType) + + assert.True(t, ok) + + assert.Equal(t, 10, instance.Field) +} + +func TestAutoBinder_constructStruct_nonStruct(t *testing.T) { + _, err := constructStruct(reflect.TypeOf("string test"), func(obj any) error { + assert.True(t, isPtr(reflect.TypeOf(obj))) + + return json.Unmarshal( + []byte(`{"field": 10}`), + obj, + ) + }) + + assert.Error(t, err) +} + +func TestAutoBinder_callHandler(t *testing.T) { + called := false + + handler := func(ctx *Context, req *myRequest) { + if ctx == nil { + t.Errorf("ctx should not passed as nil") + return + } + + if req.Field1 != "value1" { + t.Errorf("expected %v, actual %v", "value1", req.Field1) + } + + called = true + } + + rt := reflect.TypeOf(handler) + rv := reflect.ValueOf(handler) + + ctx := &Context{} + + err := callHandler(rt, rv, ctx, func(obj any) error { + return json.Unmarshal([]byte(`{"field_1": "value1"}`), obj) + }) + + if err != nil { + panic(err) + } + + if !called { + t.Error("handler should be called") + } + +} From 14bae4aef34259de3f1341a952f6b795759b10c0 Mon Sep 17 00:00:00 2001 From: abdulrahman Date: Mon, 1 Jan 2024 12:52:53 +0400 Subject: [PATCH 2/2] feature: add ability to handle the binder errors --- auto_binder.go | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/auto_binder.go b/auto_binder.go index 233546f1..374f26fd 100644 --- a/auto_binder.go +++ b/auto_binder.go @@ -6,6 +6,13 @@ import ( "reflect" ) +var ( + defaultAutoBinderErrorHandler = func(ctx *Context, err error) { + ctx.Error(err) + ctx.Abort() + } +) + type binderType func(obj any) error func isFunc(obj any) bool { @@ -75,8 +82,10 @@ func callHandler(rt reflect.Type, rv reflect.Value, ctx *Context, binder binderT // // Example: func MyGetHandler(ctx *gin.Context, request *MyRequest) {} // -// engine.GET("/endpoint", gin.AutoBinder(MyGetHandler)) -func AutoBinder(handler any) HandlerFunc { +// engine.GET("/endpoint", gin.AutoBinder(MyGetHandler)) and you can handel the errors by passing a handler +// +// engine.GET("/endpoint", gin.AutoBinder(MyGetHandler, func(ctx *gin.Context, err error) {})) +func AutoBinder(handler any, errorHandler ...func(*Context, error)) HandlerFunc { rt := reflect.TypeOf(handler) if rt.Kind() != reflect.Func { @@ -88,13 +97,18 @@ func AutoBinder(handler any) HandlerFunc { } return func(ctx *Context) { + selectedErrorHandler := defaultAutoBinderErrorHandler + if len(errorHandler) > 0 && errorHandler[0] != nil { + selectedErrorHandler = errorHandler[0] + } + rt := reflect.TypeOf(handler) rv := reflect.ValueOf(handler) if err := callHandler(rt, rv, ctx, func(obj any) error { return ctx.ShouldBind(obj) }); err != nil { - ctx.Error(err) + selectedErrorHandler(ctx, err) } } }