Merge branch 'master' into utils-coverage

This commit is contained in:
Bo-Yi Wu 2018-08-12 21:18:11 +08:00 committed by GitHub
commit 08fc478446
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 69 additions and 59 deletions

View File

@ -1560,6 +1560,26 @@ func TestContextRenderDataFromReader(t *testing.T) {
assert.Equal(t, extraHeaders["Content-Disposition"], w.HeaderMap.Get("Content-Disposition")) assert.Equal(t, extraHeaders["Content-Disposition"], w.HeaderMap.Get("Content-Disposition"))
} }
type TestResponseRecorder struct {
*httptest.ResponseRecorder
closeChannel chan bool
}
func (r *TestResponseRecorder) CloseNotify() <-chan bool {
return r.closeChannel
}
func (r *TestResponseRecorder) closeClient() {
r.closeChannel <- true
}
func CreateTestResponseRecorder() *TestResponseRecorder {
return &TestResponseRecorder{
httptest.NewRecorder(),
make(chan bool, 1),
}
}
func TestContextStream(t *testing.T) { func TestContextStream(t *testing.T) {
w := CreateTestResponseRecorder() w := CreateTestResponseRecorder()
c, _ := CreateTestContext(w) c, _ := CreateTestContext(w)

23
gin.go
View File

@ -349,7 +349,9 @@ func (engine *Engine) handleHTTPRequest(c *Context) {
// Find root of the tree for the given HTTP method // Find root of the tree for the given HTTP method
t := engine.trees t := engine.trees
for i, tl := 0, len(t); i < tl; i++ { for i, tl := 0, len(t); i < tl; i++ {
if t[i].method == httpMethod { if t[i].method != httpMethod {
continue
}
root := t[i].root root := t[i].root
// Find route in tree // Find route in tree
handlers, params, tsr := root.getValue(path, c.Params, unescape) handlers, params, tsr := root.getValue(path, c.Params, unescape)
@ -371,11 +373,12 @@ func (engine *Engine) handleHTTPRequest(c *Context) {
} }
break break
} }
}
if engine.HandleMethodNotAllowed { if engine.HandleMethodNotAllowed {
for _, tree := range engine.trees { for _, tree := range engine.trees {
if tree.method != httpMethod { if tree.method == httpMethod {
continue
}
if handlers, _, _ := tree.root.getValue(path, nil, unescape); handlers != nil { if handlers, _, _ := tree.root.getValue(path, nil, unescape); handlers != nil {
c.handlers = engine.allNoMethod c.handlers = engine.allNoMethod
serveError(c, http.StatusMethodNotAllowed, default405Body) serveError(c, http.StatusMethodNotAllowed, default405Body)
@ -383,7 +386,6 @@ func (engine *Engine) handleHTTPRequest(c *Context) {
} }
} }
} }
}
c.handlers = engine.allNoRoute c.handlers = engine.allNoRoute
serveError(c, http.StatusNotFound, default404Body) serveError(c, http.StatusNotFound, default404Body)
} }
@ -393,14 +395,16 @@ var mimePlain = []string{MIMEPlain}
func serveError(c *Context, code int, defaultMessage []byte) { func serveError(c *Context, code int, defaultMessage []byte) {
c.writermem.status = code c.writermem.status = code
c.Next() c.Next()
if !c.writermem.Written() { if c.writermem.Written() {
return
}
if c.writermem.Status() == code { if c.writermem.Status() == code {
c.writermem.Header()["Content-Type"] = mimePlain c.writermem.Header()["Content-Type"] = mimePlain
c.Writer.Write(defaultMessage) c.Writer.Write(defaultMessage)
} else { return
}
c.writermem.WriteHeaderNow() c.writermem.WriteHeaderNow()
} return
}
} }
func redirectTrailingSlash(c *Context) { func redirectTrailingSlash(c *Context) {
@ -411,10 +415,9 @@ func redirectTrailingSlash(c *Context) {
code = http.StatusTemporaryRedirect code = http.StatusTemporaryRedirect
} }
req.URL.Path = path + "/"
if length := len(path); length > 1 && path[length-1] == '/' { if length := len(path); length > 1 && path[length-1] == '/' {
req.URL.Path = path[:length-1] req.URL.Path = path[:length-1]
} else {
req.URL.Path = path + "/"
} }
debugPrint("redirecting request %d: %s --> %s", code, path, req.URL.String()) debugPrint("redirecting request %d: %s --> %s", code, path, req.URL.String())
http.Redirect(c.Writer, req, req.URL.String(), code) http.Redirect(c.Writer, req, req.URL.String(), code)

View File

@ -59,11 +59,11 @@ func cleanPath(p string) string {
case p[r] == '.' && p[r+1] == '/': case p[r] == '.' && p[r+1] == '/':
// . element // . element
r++ r += 2
case p[r] == '.' && p[r+1] == '.' && (r+2 == n || p[r+2] == '/'): case p[r] == '.' && p[r+1] == '.' && (r+2 == n || p[r+2] == '/'):
// .. element: remove to last / // .. element: remove to last /
r += 2 r += 3
if w > 1 { if w > 1 {
// can backtrack // can backtrack

View File

@ -333,6 +333,16 @@ func TestRouteNotAllowedEnabled(t *testing.T) {
assert.Equal(t, http.StatusTeapot, w.Code) assert.Equal(t, http.StatusTeapot, w.Code)
} }
func TestRouteNotAllowedEnabled2(t *testing.T) {
router := New()
router.HandleMethodNotAllowed = true
// add one methodTree to trees
router.addRoute("POST", "/", HandlersChain{func(_ *Context) {}})
router.GET("/path2", func(c *Context) {})
w := performRequest(router, "POST", "/path2")
assert.Equal(t, http.StatusMethodNotAllowed, w.Code)
}
func TestRouteNotAllowedDisabled(t *testing.T) { func TestRouteNotAllowedDisabled(t *testing.T) {
router := New() router := New()
router.HandleMethodNotAllowed = false router.HandleMethodNotAllowed = false

View File

@ -4,10 +4,7 @@
package gin package gin
import ( import "net/http"
"net/http"
"net/http/httptest"
)
// CreateTestContext returns a fresh engine and context for testing purposes // CreateTestContext returns a fresh engine and context for testing purposes
func CreateTestContext(w http.ResponseWriter) (c *Context, r *Engine) { func CreateTestContext(w http.ResponseWriter) (c *Context, r *Engine) {
@ -17,23 +14,3 @@ func CreateTestContext(w http.ResponseWriter) (c *Context, r *Engine) {
c.writermem.reset(w) c.writermem.reset(w)
return return
} }
type TestResponseRecorder struct {
*httptest.ResponseRecorder
closeChannel chan bool
}
func (r *TestResponseRecorder) CloseNotify() <-chan bool {
return r.closeChannel
}
func (r *TestResponseRecorder) closeClient() {
r.closeChannel <- true
}
func CreateTestResponseRecorder() *TestResponseRecorder {
return &TestResponseRecorder{
httptest.NewRecorder(),
make(chan bool, 1),
}
}