diff --git a/context.go b/context.go index ba2dfcbd..b7112220 100644 --- a/context.go +++ b/context.go @@ -925,7 +925,8 @@ func (c *Context) Render(code int, r render.Render) { if err := r.Render(c.Writer); err != nil { // if err is net error, pushing error to c.Errors - if _, ok := err.(*net.OpError); ok { + netOpErr := &net.OpError{} + if ok := errors.As(err, &netOpErr); ok { _ = c.Error(err) c.Abort() } else { diff --git a/context_test.go b/context_test.go index 1dec902c..2f0bf345 100644 --- a/context_test.go +++ b/context_test.go @@ -32,7 +32,15 @@ import ( var _ context.Context = (*Context)(nil) -var errTestRender = errors.New("TestRender") +var errTestRender = errors.New("TestPanicRender") + +var errNetOpErr = &net.OpError{ + Op: "testneterr", + Net: "", + Source: nil, + Addr: nil, + Err: nil, +} // Unit tests TODO // func (c *Context) File(filepath string) { @@ -645,21 +653,41 @@ func TestContextBodyAllowedForStatus(t *testing.T) { assert.True(t, true, bodyAllowedForStatus(http.StatusInternalServerError)) } -type TestRender struct{} +type TestPanicRender struct{} -func (*TestRender) Render(http.ResponseWriter) error { +func (*TestPanicRender) Render(http.ResponseWriter) error { return errTestRender } -func (*TestRender) WriteContentType(http.ResponseWriter) {} +func (*TestPanicRender) WriteContentType(http.ResponseWriter) {} + +func TestContextRenderPanicIfErr(t *testing.T) { + defer func() { + r := recover() + assert.Equal(t, fmt.Sprint(errTestRender), fmt.Sprint(r)) + }() -func TestContextRenderIfErr(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) - c.Render(http.StatusOK, &TestRender{}) + c.Render(http.StatusOK, &TestPanicRender{}) +} - assert.Equal(t, errorMsgs{&Error{Err: errTestRender, Type: 1}}, c.Errors) +type TestNetErrorRender struct{} + +func (*TestNetErrorRender) Render(http.ResponseWriter) error { + return errNetOpErr +} + +func (*TestNetErrorRender) WriteContentType(http.ResponseWriter) {} + +func TestContextRenderIfNetErr(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + + c.Render(http.StatusOK, &TestNetErrorRender{}) + + assert.Equal(t, errorMsgs{&Error{Err: errNetOpErr, Type: 1}}, c.Errors) } // Tests that the response is serialized as JSON