diff --git a/native_context.go b/native_context.go new file mode 100644 index 00000000..51e324f0 --- /dev/null +++ b/native_context.go @@ -0,0 +1,26 @@ +// +build go1.7 + +package gin + +import ( + "context" + "net/http" +) + +const ParamsKey = "_gin-gonic/gin/paramskey" + +// WithParams is a helper function to add Params in native context +// Returns a http request +func WithParams(r *http.Request, params Params) *http.Request { + ctx := context.WithValue(r.Context(), ParamsKey, params) + return r.WithContext(ctx) +} + +// GetParams is a helper function to get Params in native context +// Returns a Gin Params +func GetParams(r *http.Request) Params { + if params := r.Context().Value(ParamsKey); params != nil { + return params.(Params) + } + return nil +} diff --git a/native_context_1.6.go b/native_context_1.6.go new file mode 100644 index 00000000..45703a7a --- /dev/null +++ b/native_context_1.6.go @@ -0,0 +1,17 @@ +// +build !go1.7 + +package gin + +import "net/http" + +// WithParams is a helper function to add Params in native context +// Returns a http request +func WithParams(r *http.Request, params Params) *http.Request { + return r +} + +// GetParams is a helper function to get Params in native context +// Returns a Gin Params +func GetParams(r *http.Request) Params { + return nil +} diff --git a/native_context_test.go b/native_context_test.go new file mode 100644 index 00000000..6c042694 --- /dev/null +++ b/native_context_test.go @@ -0,0 +1,42 @@ +// +build go1.7 + +package gin + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetParamsWithWrap(t *testing.T) { + router := New() + router.GET("/hello/:name", WrapH(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + assert.Equal(t, "/hello/gopher", req.URL.Path) + assert.Equal(t, "gopher", GetParams(req).ByName("name")) + }))) + + router.GET("/hello2/:name", WrapF(func(w http.ResponseWriter, req *http.Request) { + assert.Equal(t, "/hello2/gopher", req.URL.Path) + assert.Equal(t, "gopher", GetParams(req).ByName("name")) + })) + + router.GET("/hello", WrapH(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + assert.Equal(t, "/hello", req.URL.Path) + assert.Equal(t, "", GetParams(req).ByName("name")) + }))) + + w := performRequest(router, "GET", "/hello/gopher") + assert.Equal(t, 200, w.Code) + + w = performRequest(router, "GET", "/hello2/gopher") + assert.Equal(t, 200, w.Code) + + w = performRequest(router, "GET", "/hello") + assert.Equal(t, 200, w.Code) +} + +func TestGetParamsWithRequest(t *testing.T) { + req := &http.Request{} + assert.Equal(t, "", GetParams(req).ByName("name")) +} diff --git a/utils.go b/utils.go index bf32c775..f82c7759 100644 --- a/utils.go +++ b/utils.go @@ -37,7 +37,12 @@ func Bind(val interface{}) HandlerFunc { // Returns a Gin middleware func WrapF(f http.HandlerFunc) HandlerFunc { return func(c *Context) { - f(c.Writer, c.Request) + r := c.Request + if len(c.Params) > 0 { + r = WithParams(r, c.Params) + } + + f(c.Writer, r) } } @@ -45,7 +50,12 @@ func WrapF(f http.HandlerFunc) HandlerFunc { // Returns a Gin middleware func WrapH(h http.Handler) HandlerFunc { return func(c *Context) { - h.ServeHTTP(c.Writer, c.Request) + r := c.Request + if len(c.Params) > 0 { + r = WithParams(r, c.Params) + } + + h.ServeHTTP(c.Writer, r) } }