improve wrap to add params in native context

This commit is contained in:
qneyrat 2018-06-29 00:08:04 +02:00
parent eb9f313144
commit e6fd26cea5
4 changed files with 97 additions and 2 deletions

26
native_context.go Normal file
View File

@ -0,0 +1,26 @@
// +build go1.7
package gin
import (
"context"
"net/http"
)
const ParamsKey = "_gin-gonic/gin/paramskey"
// WithParams is a helper function to add Params in native context
// Returns a http request
func WithParams(r *http.Request, params Params) *http.Request {
ctx := context.WithValue(r.Context(), ParamsKey, params)
return r.WithContext(ctx)
}
// GetParams is a helper function to get Params in native context
// Returns a Gin Params
func GetParams(r *http.Request) Params {
if params := r.Context().Value(ParamsKey); params != nil {
return params.(Params)
}
return nil
}

17
native_context_1.6.go Normal file
View File

@ -0,0 +1,17 @@
// +build !go1.7
package gin
import "net/http"
// WithParams is a helper function to add Params in native context
// Returns a http request
func WithParams(r *http.Request, params Params) *http.Request {
return r
}
// GetParams is a helper function to get Params in native context
// Returns a Gin Params
func GetParams(r *http.Request) Params {
return nil
}

42
native_context_test.go Normal file
View File

@ -0,0 +1,42 @@
// +build go1.7
package gin
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
func TestGetParamsWithWrap(t *testing.T) {
router := New()
router.GET("/hello/:name", WrapH(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
assert.Equal(t, "/hello/gopher", req.URL.Path)
assert.Equal(t, "gopher", GetParams(req).ByName("name"))
})))
router.GET("/hello2/:name", WrapF(func(w http.ResponseWriter, req *http.Request) {
assert.Equal(t, "/hello2/gopher", req.URL.Path)
assert.Equal(t, "gopher", GetParams(req).ByName("name"))
}))
router.GET("/hello", WrapH(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
assert.Equal(t, "/hello", req.URL.Path)
assert.Equal(t, "", GetParams(req).ByName("name"))
})))
w := performRequest(router, "GET", "/hello/gopher")
assert.Equal(t, 200, w.Code)
w = performRequest(router, "GET", "/hello2/gopher")
assert.Equal(t, 200, w.Code)
w = performRequest(router, "GET", "/hello")
assert.Equal(t, 200, w.Code)
}
func TestGetParamsWithRequest(t *testing.T) {
req := &http.Request{}
assert.Equal(t, "", GetParams(req).ByName("name"))
}

View File

@ -37,7 +37,12 @@ func Bind(val interface{}) HandlerFunc {
// Returns a Gin middleware
func WrapF(f http.HandlerFunc) HandlerFunc {
return func(c *Context) {
f(c.Writer, c.Request)
r := c.Request
if len(c.Params) > 0 {
r = WithParams(r, c.Params)
}
f(c.Writer, r)
}
}
@ -45,7 +50,12 @@ func WrapF(f http.HandlerFunc) HandlerFunc {
// Returns a Gin middleware
func WrapH(h http.Handler) HandlerFunc {
return func(c *Context) {
h.ServeHTTP(c.Writer, c.Request)
r := c.Request
if len(c.Params) > 0 {
r = WithParams(r, c.Params)
}
h.ServeHTTP(c.Writer, r)
}
}