Compare commits

...

4 Commits

Author SHA1 Message Date
Abdulrahman Tayara
ce6b1b0697
Merge 14bae4aef34259de3f1341a952f6b795759b10c0 into 3ab698dc5110af1977d57226e4995c57dd34c233 2026-01-17 17:03:25 +05:00
OHZEKI Naoki
3ab698dc51
refactor(recovery): smart error comparison (#4142)
* refactor(recovery): rename var in CustomRecoveryWithWriter

* refactor(recovery): smart error comparison

* test(recovery): Directly reference the syscall error string
2026-01-17 16:40:43 +08:00
abdulrahman
14bae4aef3 feature: add ability to handle the binder errors 2024-01-01 12:52:53 +04:00
abdulrahman
0e0ea62a94 feature: implement request auto binder utility 2023-12-29 02:39:19 +04:00
4 changed files with 280 additions and 28 deletions

114
auto_binder.go Normal file
View File

@ -0,0 +1,114 @@
package gin
import (
"errors"
"fmt"
"reflect"
)
var (
defaultAutoBinderErrorHandler = func(ctx *Context, err error) {
ctx.Error(err)
ctx.Abort()
}
)
type binderType func(obj any) error
func isFunc(obj any) bool {
return reflect.TypeOf(obj).Kind() == reflect.Func
}
func isGinContext(rt reflect.Type) bool {
return rt == reflect.TypeOf((*Context)(nil))
}
func isPtr(rt reflect.Type) bool {
return rt.Kind() == reflect.Pointer
}
func isStruct(rt reflect.Type) bool {
return rt.Kind() == reflect.Struct
}
func constructStruct(prt reflect.Type, binder binderType) (reflect.Value, error) {
var pInstancePtr any
if isPtr(prt) {
pInstancePtr = reflect.New(prt.Elem()).Interface()
} else {
pInstancePtr = reflect.New(prt).Interface()
}
if err := binder(pInstancePtr); err != nil {
return reflect.Value{}, err
}
if prt.Kind() == reflect.Pointer {
return reflect.ValueOf(pInstancePtr), nil
}
return reflect.ValueOf(pInstancePtr).Elem(), nil
}
func callHandler(rt reflect.Type, rv reflect.Value, ctx *Context, binder binderType) error {
numberOfParams := rt.NumIn()
var args []reflect.Value
for i := 0; i < numberOfParams; i++ {
prt := rt.In(i)
if isGinContext(prt) {
args = append(args, reflect.ValueOf(ctx))
continue
}
if isStruct(prt) || isStruct(prt.Elem()) {
if prv, err := constructStruct(prt, binder); err != nil {
return err
} else {
args = append(args, prv)
}
}
}
rv.Call(args)
return nil
}
// AutoBinder is a handler wrapper that binds the actual handler's request.
//
// Example: func MyGetHandler(ctx *gin.Context, request *MyRequest) {}
//
// engine.GET("/endpoint", gin.AutoBinder(MyGetHandler)) and you can handel the errors by passing a handler
//
// engine.GET("/endpoint", gin.AutoBinder(MyGetHandler, func(ctx *gin.Context, err error) {}))
func AutoBinder(handler any, errorHandler ...func(*Context, error)) HandlerFunc {
rt := reflect.TypeOf(handler)
if rt.Kind() != reflect.Func {
panic(errors.New("invalid handler type"))
}
if rt.NumIn() == 0 {
panic(fmt.Errorf("handler should have at least one parameter, handler: %v", rt.Name()))
}
return func(ctx *Context) {
selectedErrorHandler := defaultAutoBinderErrorHandler
if len(errorHandler) > 0 && errorHandler[0] != nil {
selectedErrorHandler = errorHandler[0]
}
rt := reflect.TypeOf(handler)
rv := reflect.ValueOf(handler)
if err := callHandler(rt, rv, ctx, func(obj any) error {
return ctx.ShouldBind(obj)
}); err != nil {
selectedErrorHandler(ctx, err)
}
}
}

144
auto_binder_test.go Normal file
View File

@ -0,0 +1,144 @@
package gin
import (
"encoding/json"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
type myRequest struct {
Field1 string `json:"field_1"`
}
func TestAutoBinder_isFunc(t *testing.T) {
tests := []struct {
name string
input any
want bool
}{
{
"valid function",
func(string, int) error { return nil },
true,
},
{
"valid zero-param function",
func() error { return nil },
true,
},
{
"invalid function",
func() string { return "" }(),
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := isFunc(tt.input)
assert.Equal(t, tt.want, actual)
})
}
}
func TestAutoBinder_isGinContext(t *testing.T) {
assert.True(t, isGinContext(reflect.TypeOf(&Context{})))
assert.False(t, isGinContext(reflect.TypeOf(Context{})))
assert.False(t, isGinContext(reflect.TypeOf([]string{})))
}
func TestAutoBinder_constructStruct_pointer(t *testing.T) {
type myType struct {
Field int `json:"field"`
}
rv, err := constructStruct(reflect.TypeOf(&myType{}), func(obj any) error {
assert.True(t, isPtr(reflect.TypeOf(obj)))
return json.Unmarshal(
[]byte(`{"field": 10}`),
obj,
)
})
assert.NoError(t, err)
instance, ok := rv.Interface().(*myType)
assert.True(t, ok)
assert.Equal(t, 10, instance.Field)
}
func TestAutoBinder_constructStruct_nonPointer(t *testing.T) {
type myType struct {
Field int `json:"field"`
}
rv, err := constructStruct(reflect.TypeOf(myType{}), func(obj any) error {
assert.True(t, isPtr(reflect.TypeOf(obj)))
return json.Unmarshal(
[]byte(`{"field": 10}`),
obj,
)
})
assert.NoError(t, err)
instance, ok := rv.Interface().(myType)
assert.True(t, ok)
assert.Equal(t, 10, instance.Field)
}
func TestAutoBinder_constructStruct_nonStruct(t *testing.T) {
_, err := constructStruct(reflect.TypeOf("string test"), func(obj any) error {
assert.True(t, isPtr(reflect.TypeOf(obj)))
return json.Unmarshal(
[]byte(`{"field": 10}`),
obj,
)
})
assert.Error(t, err)
}
func TestAutoBinder_callHandler(t *testing.T) {
called := false
handler := func(ctx *Context, req *myRequest) {
if ctx == nil {
t.Errorf("ctx should not passed as nil")
return
}
if req.Field1 != "value1" {
t.Errorf("expected %v, actual %v", "value1", req.Field1)
}
called = true
}
rt := reflect.TypeOf(handler)
rv := reflect.ValueOf(handler)
ctx := &Context{}
err := callHandler(rt, rv, ctx, func(obj any) error {
return json.Unmarshal([]byte(`{"field_1": "value1"}`), obj)
})
if err != nil {
panic(err)
}
if !called {
t.Error("handler should be called")
}
}

View File

@ -12,12 +12,12 @@ import (
"fmt"
"io"
"log"
"net"
"net/http"
"net/http/httputil"
"os"
"runtime"
"strings"
"syscall"
"time"
"github.com/gin-gonic/gin/internal/bytesconv"
@ -57,40 +57,33 @@ func CustomRecoveryWithWriter(out io.Writer, handle RecoveryFunc) HandlerFunc {
}
return func(c *Context) {
defer func() {
if err := recover(); err != nil {
if rec := recover(); rec != nil {
// Check for a broken connection, as it is not really a
// condition that warrants a panic stack trace.
var brokenPipe bool
if ne, ok := err.(*net.OpError); ok {
var se *os.SyscallError
if errors.As(ne, &se) {
seStr := strings.ToLower(se.Error())
if strings.Contains(seStr, "broken pipe") ||
strings.Contains(seStr, "connection reset by peer") {
brokenPipe = true
}
}
}
if e, ok := err.(error); ok && errors.Is(e, http.ErrAbortHandler) {
brokenPipe = true
var isBrokenPipe bool
err, ok := rec.(error)
if ok {
isBrokenPipe = errors.Is(err, syscall.EPIPE) ||
errors.Is(err, syscall.ECONNRESET) ||
errors.Is(err, http.ErrAbortHandler)
}
if logger != nil {
if brokenPipe {
logger.Printf("%s\n%s%s", err, secureRequestDump(c.Request), reset)
if isBrokenPipe {
logger.Printf("%s\n%s%s", rec, secureRequestDump(c.Request), reset)
} else if IsDebugging() {
logger.Printf("[Recovery] %s panic recovered:\n%s\n%s\n%s%s",
timeFormat(time.Now()), secureRequestDump(c.Request), err, stack(stackSkip), reset)
timeFormat(time.Now()), secureRequestDump(c.Request), rec, stack(stackSkip), reset)
} else {
logger.Printf("[Recovery] %s panic recovered:\n%s\n%s%s",
timeFormat(time.Now()), err, stack(stackSkip), reset)
timeFormat(time.Now()), rec, stack(stackSkip), reset)
}
}
if brokenPipe {
if isBrokenPipe {
// If the connection is dead, we can't write a status to it.
c.Error(err.(error)) //nolint: errcheck
c.Error(err) //nolint: errcheck
c.Abort()
} else {
handle(c, err)
handle(c, rec)
}
}
}()

View File

@ -98,13 +98,13 @@ func TestFunction(t *testing.T) {
func TestPanicWithBrokenPipe(t *testing.T) {
const expectCode = 204
expectMsgs := map[syscall.Errno]string{
syscall.EPIPE: "broken pipe",
syscall.ECONNRESET: "connection reset by peer",
expectErrnos := []syscall.Errno{
syscall.EPIPE,
syscall.ECONNRESET,
}
for errno, expectMsg := range expectMsgs {
t.Run(expectMsg, func(t *testing.T) {
for _, errno := range expectErrnos {
t.Run("Recovery from "+errno.Error(), func(t *testing.T) {
var buf strings.Builder
router := New()
@ -122,7 +122,8 @@ func TestPanicWithBrokenPipe(t *testing.T) {
w := PerformRequest(router, http.MethodGet, "/recovery")
// TEST
assert.Equal(t, expectCode, w.Code)
assert.Contains(t, strings.ToLower(buf.String()), expectMsg)
assert.Contains(t, strings.ToLower(buf.String()), errno.Error())
assert.NotContains(t, strings.ToLower(buf.String()), "[Recovery]")
})
}
}