diff --git a/recovery.go b/recovery.go index bbf1d565..cbe126ee 100644 --- a/recovery.go +++ b/recovery.go @@ -42,6 +42,7 @@ func CustomRecovery(handle RecoveryFunc) HandlerFunc { } // RecoveryWithWriter returns a middleware for a given writer that recovers from any panics and writes a 500 if there was one. +// If recovery handlers are provided, only the first one is used. func RecoveryWithWriter(out io.Writer, recovery ...RecoveryFunc) HandlerFunc { if len(recovery) > 0 { return CustomRecoveryWithWriter(out, recovery[0]) diff --git a/recovery_test.go b/recovery_test.go index 028c4ad6..192c7352 100644 --- a/recovery_test.go +++ b/recovery_test.go @@ -256,6 +256,33 @@ func TestRecoveryWithWriterWithCustomRecovery(t *testing.T) { SetMode(TestMode) } +func TestRecoveryWithWriterUsesOnlyFirstRecoveryFunc(t *testing.T) { + buffer := new(strings.Builder) + router := New() + + calls := 0 + first := func(c *Context, err any) { + calls++ + assert.Equal(t, "Oops, Houston, we have a problem", err) + c.AbortWithStatus(http.StatusBadRequest) + } + second := func(c *Context, err any) { + calls += 100 + c.AbortWithStatus(http.StatusTeapot) + } + + router.Use(RecoveryWithWriter(buffer, first, second)) + router.GET("/recovery", func(_ *Context) { + panic("Oops, Houston, we have a problem") + }) + + w := PerformRequest(router, http.MethodGet, "/recovery") + + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Equal(t, 1, calls) + assert.Contains(t, buffer.String(), "panic recovered") +} + func TestSecureRequestDump(t *testing.T) { tests := []struct { name string