// Copyright 2014 Manu Martinez-Almeida. All rights reserved. // Use of this source code is governed by a MIT style // license that can be found in the LICENSE file. package gin import ( "net" "net/http" "os" "strings" "syscall" "testing" "github.com/stretchr/testify/assert" ) func TestPanicClean(t *testing.T) { buffer := new(strings.Builder) router := New() password := "my-super-secret-password" router.Use(RecoveryWithWriter(buffer)) router.GET("/recovery", func(c *Context) { c.AbortWithStatus(http.StatusBadRequest) panic("Oupps, Houston, we have a problem") }) // RUN w := PerformRequest(router, http.MethodGet, "/recovery", header{ Key: "Host", Value: "www.google.com", }, header{ Key: "Authorization", Value: "Bearer " + password, }, header{ Key: "Content-Type", Value: "application/json", }, ) // TEST assert.Equal(t, http.StatusBadRequest, w.Code) // Check the buffer does not have the secret key assert.NotContains(t, buffer.String(), password) } // TestPanicInHandler assert that panic has been recovered. func TestPanicInHandler(t *testing.T) { buffer := new(strings.Builder) router := New() router.Use(RecoveryWithWriter(buffer)) router.GET("/recovery", func(_ *Context) { panic("Oupps, Houston, we have a problem") }) // RUN w := PerformRequest(router, http.MethodGet, "/recovery") // TEST assert.Equal(t, http.StatusInternalServerError, w.Code) assert.Contains(t, buffer.String(), "panic recovered") assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem") assert.Contains(t, buffer.String(), t.Name()) assert.NotContains(t, buffer.String(), "GET /recovery") // Debug mode prints the request SetMode(DebugMode) // RUN w = PerformRequest(router, http.MethodGet, "/recovery") // TEST assert.Equal(t, http.StatusInternalServerError, w.Code) assert.Contains(t, buffer.String(), "GET /recovery") SetMode(TestMode) } // TestPanicWithAbort assert that panic has been recovered even if context.Abort was used. func TestPanicWithAbort(t *testing.T) { router := New() router.Use(RecoveryWithWriter(nil)) router.GET("/recovery", func(c *Context) { c.AbortWithStatus(http.StatusBadRequest) panic("Oupps, Houston, we have a problem") }) // RUN w := PerformRequest(router, http.MethodGet, "/recovery") // TEST assert.Equal(t, http.StatusBadRequest, w.Code) } func TestSource(t *testing.T) { bs := source(nil, 0) assert.Equal(t, dunnoBytes, bs) in := [][]byte{ []byte("Hello world."), []byte("Hi, gin.."), } bs = source(in, 10) assert.Equal(t, dunnoBytes, bs) bs = source(in, 1) assert.Equal(t, []byte("Hello world."), bs) } func TestFunction(t *testing.T) { bs := function(1) assert.Equal(t, dunno, bs) } // TestPanicWithBrokenPipe asserts that recovery specifically handles // writing responses to broken pipes func TestPanicWithBrokenPipe(t *testing.T) { const expectCode = 204 expectMsgs := map[syscall.Errno]string{ syscall.EPIPE: "broken pipe", syscall.ECONNRESET: "connection reset by peer", } for errno, expectMsg := range expectMsgs { t.Run(expectMsg, func(t *testing.T) { var buf strings.Builder router := New() router.Use(RecoveryWithWriter(&buf)) router.GET("/recovery", func(c *Context) { // Start writing response c.Header("X-Test", "Value") c.Status(expectCode) // Oops. Client connection closed e := &net.OpError{Err: &os.SyscallError{Err: errno}} panic(e) }) // RUN w := PerformRequest(router, http.MethodGet, "/recovery") // TEST assert.Equal(t, expectCode, w.Code) assert.Contains(t, strings.ToLower(buf.String()), expectMsg) }) } } func TestCustomRecoveryWithWriter(t *testing.T) { errBuffer := new(strings.Builder) buffer := new(strings.Builder) router := New() handleRecovery := func(c *Context, err any) { errBuffer.WriteString(err.(string)) c.AbortWithStatus(http.StatusBadRequest) } router.Use(CustomRecoveryWithWriter(buffer, handleRecovery)) router.GET("/recovery", func(_ *Context) { panic("Oupps, Houston, we have a problem") }) // RUN w := PerformRequest(router, http.MethodGet, "/recovery") // TEST assert.Equal(t, http.StatusBadRequest, w.Code) assert.Contains(t, buffer.String(), "panic recovered") assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem") assert.Contains(t, buffer.String(), t.Name()) assert.NotContains(t, buffer.String(), "GET /recovery") // Debug mode prints the request SetMode(DebugMode) // RUN w = PerformRequest(router, http.MethodGet, "/recovery") // TEST assert.Equal(t, http.StatusBadRequest, w.Code) assert.Contains(t, buffer.String(), "GET /recovery") assert.Equal(t, strings.Repeat("Oupps, Houston, we have a problem", 2), errBuffer.String()) SetMode(TestMode) } func TestCustomRecovery(t *testing.T) { errBuffer := new(strings.Builder) buffer := new(strings.Builder) router := New() DefaultErrorWriter = buffer handleRecovery := func(c *Context, err any) { errBuffer.WriteString(err.(string)) c.AbortWithStatus(http.StatusBadRequest) } router.Use(CustomRecovery(handleRecovery)) router.GET("/recovery", func(_ *Context) { panic("Oupps, Houston, we have a problem") }) // RUN w := PerformRequest(router, http.MethodGet, "/recovery") // TEST assert.Equal(t, http.StatusBadRequest, w.Code) assert.Contains(t, buffer.String(), "panic recovered") assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem") assert.Contains(t, buffer.String(), t.Name()) assert.NotContains(t, buffer.String(), "GET /recovery") // Debug mode prints the request SetMode(DebugMode) // RUN w = PerformRequest(router, http.MethodGet, "/recovery") // TEST assert.Equal(t, http.StatusBadRequest, w.Code) assert.Contains(t, buffer.String(), "GET /recovery") assert.Equal(t, strings.Repeat("Oupps, Houston, we have a problem", 2), errBuffer.String()) SetMode(TestMode) } func TestRecoveryWithWriterWithCustomRecovery(t *testing.T) { errBuffer := new(strings.Builder) buffer := new(strings.Builder) router := New() DefaultErrorWriter = buffer handleRecovery := func(c *Context, err any) { errBuffer.WriteString(err.(string)) c.AbortWithStatus(http.StatusBadRequest) } router.Use(RecoveryWithWriter(DefaultErrorWriter, handleRecovery)) router.GET("/recovery", func(_ *Context) { panic("Oupps, Houston, we have a problem") }) // RUN w := PerformRequest(router, http.MethodGet, "/recovery") // TEST assert.Equal(t, http.StatusBadRequest, w.Code) assert.Contains(t, buffer.String(), "panic recovered") assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem") assert.Contains(t, buffer.String(), t.Name()) assert.NotContains(t, buffer.String(), "GET /recovery") // Debug mode prints the request SetMode(DebugMode) // RUN w = PerformRequest(router, http.MethodGet, "/recovery") // TEST assert.Equal(t, http.StatusBadRequest, w.Code) assert.Contains(t, buffer.String(), "GET /recovery") assert.Equal(t, strings.Repeat("Oupps, Houston, we have a problem", 2), errBuffer.String()) SetMode(TestMode) } func TestSecureRequestDump(t *testing.T) { tests := []struct { name string req *http.Request wantContains string wantNotContain string }{ { name: "Authorization header standard case", req: func() *http.Request { r, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) r.Header.Set("Authorization", "Bearer secret-token") return r }(), wantContains: "Authorization: *", wantNotContain: "Bearer secret-token", }, { name: "authorization header lowercase", req: func() *http.Request { r, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) r.Header.Set("authorization", "some-secret") return r }(), wantContains: "Authorization: *", wantNotContain: "some-secret", }, { name: "Authorization header mixed case", req: func() *http.Request { r, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) r.Header.Set("AuThOrIzAtIoN", "token123") return r }(), wantContains: "Authorization: *", wantNotContain: "token123", }, { name: "No Authorization header", req: func() *http.Request { r, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) r.Header.Set("Content-Type", "application/json") return r }(), wantContains: "", wantNotContain: "Authorization: *", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := secureRequestDump(tt.req) if tt.wantContains != "" && !strings.Contains(result, tt.wantContains) { t.Errorf("maskHeaders() = %q, want contains %q", result, tt.wantContains) } if tt.wantNotContain != "" && strings.Contains(result, tt.wantNotContain) { t.Errorf("maskHeaders() = %q, want NOT contain %q", result, tt.wantNotContain) } }) } }