From 94dee576bfda2e05a8feaadb82fffc23c196ce57 Mon Sep 17 00:00:00 2001 From: fan Date: Sun, 5 May 2019 19:04:54 +0800 Subject: [PATCH] add support http forward --- context.go | 21 +++++++++++++++++++++ context_test.go | 38 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/context.go b/context.go index 5dc7f8a0..db18187b 100644 --- a/context.go +++ b/context.go @@ -13,6 +13,7 @@ import ( "mime/multipart" "net" "net/http" + "net/http/httputil" "net/url" "os" "strings" @@ -1026,3 +1027,23 @@ func (c *Context) Value(key interface{}) interface{} { } return nil } + +// ReverseProxy is an HTTP Handler that takes an incoming request and +// sends it to another server, proxying the response back to the client. +// You can use it to forward. +func (c *Context) Forward(target string) { + host := c.Request.Host + scheme := c.Request.URL.Scheme + if scheme == "" { + scheme = "http" + } + var proxy = httputil.ReverseProxy{ + Director: func(req *http.Request) { + req.URL.Scheme = scheme + req.URL.Host = host + req.Host = host + }, + } + c.Request.URL.Path = target + proxy.ServeHTTP(c.Writer, c.Request) +} diff --git a/context_test.go b/context_test.go index 0da5fbe6..43b7499a 100644 --- a/context_test.go +++ b/context_test.go @@ -10,21 +10,22 @@ import ( "fmt" "html/template" "io" + "io/ioutil" "mime/multipart" "net/http" "net/http/httptest" "reflect" "strings" + "sync" "testing" "time" "github.com/gin-contrib/sse" "github.com/gin-gonic/gin/binding" + testdata "github.com/gin-gonic/gin/testdata/protoexample" "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" "golang.org/x/net/context" - - testdata "github.com/gin-gonic/gin/testdata/protoexample" ) var _ context.Context = &Context{} @@ -1812,3 +1813,36 @@ func TestContextResetInHandler(t *testing.T) { c.Next() }) } + +func TestContext_Forward(t *testing.T) { + g := sync.WaitGroup{} + g.Add(1) + go func(g *sync.WaitGroup) { + g.Done() + e := Default() + e.GET("/test", func(c *Context) { + c.Forward("/test2") + }) + e.GET("/test2", func(c *Context) { + p := c.Query("p") + c.String(http.StatusOK, p) + + }) + e.Run(":9998") + + }(&g) + + g.Wait() + + p := "test" + resp, err := http.Get("http://127.0.0.1:9998/test?p=" + p) + assert.NoError(t, err) + defer resp.Body.Close() + + bytes, err := ioutil.ReadAll(resp.Body) + + assert.NoError(t, err) + + assert.Equal(t, p, string(bytes)) + +}