diff --git a/context.go b/context.go index 9e7a6578..9fa3cb81 100644 --- a/context.go +++ b/context.go @@ -729,13 +729,13 @@ func (c *Context) ShouldBindBodyWith(obj interface{}, bb binding.BindingBody) (e // X-Real-IP and X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy. // Use X-Forwarded-For before X-Real-Ip as nginx uses X-Real-Ip with the proxy's IP. func (c *Context) ClientIP() string { - if c.engine.ForwardedByClientIP && c.engine.RemoteIPHeaders != nil { - for _, header := range c.engine.RemoteIPHeaders { - ipChain := filterIPsFromUntrustedProxies(c.requestHeader(header), c.Request, c.engine) - if len(ipChain) > 0 { - return ipChain[0] - } - } + ip, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr)) + if err != nil { + return "" + } + remoteIP := net.ParseIP(ip) + if remoteIP == nil { + return "" } if c.engine.AppEngine { @@ -744,82 +744,46 @@ func (c *Context) ClientIP() string { } } - ip, _ := getTransportPeerIPForRequest(c.Request) + if c.shouldCheckIPHeaders() { + for _, cidr := range c.engine.trustedCIDRs { + if cidr.Contains(remoteIP) { + for _, headerName := range c.engine.RemoteIPHeaders { + ip, valid := validateHeader(c.requestHeader(headerName)) + if valid { + return ip + } + } + } + } + } - return ip + return remoteIP.String() } -func filterIPsFromUntrustedProxies(XForwardedForHeader string, req *http.Request, e *Engine) []string { - var items, out []string - if XForwardedForHeader != "" { - items = strings.Split(XForwardedForHeader, ",") - } else { - return []string{} - } - if peerIP, err := getTransportPeerIPForRequest(req); err == nil { - items = append(items, peerIP) - } +func (c *Context) shouldCheckIPHeaders() bool { + return c.engine.ForwardedByClientIP && + c.engine.RemoteIPHeaders != nil && + len(c.engine.RemoteIPHeaders) > 0 && + c.engine.trustedCIDRs != nil +} - for i := len(items) - 1; i >= 0; i-- { - item := strings.TrimSpace(items[i]) - ip := net.ParseIP(item) +func validateHeader(header string) (clientIP string, valid bool) { + if header == "" { + return + } + items := strings.Split(header, ",") + for i, ipStr := range items { + ipStr = strings.TrimSpace(ipStr) + ip := net.ParseIP(ipStr) if ip == nil { - return out + return "", false } - - out = prependString(ip.String(), out) - if !isTrustedProxy(ip, e) { - return out - } - // out = prependString(ip.String(), out) - } - return out -} - -func isTrustedProxy(ip net.IP, e *Engine) bool { - for _, trustedProxy := range e.TrustedProxies { - if _, ipnet, err := net.ParseCIDR(trustedProxy); err == nil { - if ipnet.Contains(ip) { - return true - } - continue - } - - if proxyIP := net.ParseIP(trustedProxy); proxyIP != nil { - if proxyIP.Equal(ip) { - return true - } - continue - } - - if addrs, err := e.lookupHost(trustedProxy); err == nil { - for _, proxyAddr := range addrs { - proxyIP := net.ParseIP(proxyAddr) - if proxyIP == nil { - continue - } - if proxyIP.Equal(ip) { - return true - } - } + if i == 0 { + clientIP = ipStr + valid = true } } - return false -} - -func prependString(ip string, ipList []string) []string { - ipList = append(ipList, "") - copy(ipList[1:], ipList) - ipList[0] = string(ip) - return ipList -} - -func getTransportPeerIPForRequest(req *http.Request) (string, error) { - var err error - if ip, _, err := net.SplitHostPort(strings.TrimSpace(req.RemoteAddr)); err == nil { - return ip, nil - } - return "", err + return } // ContentType returns the Content-Type header of the request. diff --git a/gin.go b/gin.go index 876bfe8a..877d3238 100644 --- a/gin.go +++ b/gin.go @@ -11,6 +11,7 @@ import ( "net/http" "os" "path" + "strings" "sync" "github.com/gin-gonic/gin/internal/bytesconv" @@ -118,6 +119,7 @@ type Engine struct { pool sync.Pool trees methodTrees maxParams uint16 + trustedCIDRs []*net.IPNet } var _ IRouter = &Engine{} @@ -312,12 +314,35 @@ func iterate(path, method string, routes RoutesInfo, root *node) RoutesInfo { func (engine *Engine) Run(addr ...string) (err error) { defer func() { debugPrintError(err) }() + trustedCIDRs, err := engine.prepareCIDR() + if err != nil { + return err + } + engine.trustedCIDRs = trustedCIDRs address := resolveAddress(addr) debugPrint("Listening and serving HTTP on %s\n", address) err = http.ListenAndServe(address, engine) return } +func (engine *Engine) prepareCIDR() ([]*net.IPNet, error) { + if engine.TrustedProxies != nil { + cidr := make([]*net.IPNet, len(engine.TrustedProxies), 0) + for _, trustedProxy := range engine.TrustedProxies { + if strings.Contains(trustedProxy, "/") { + trustedProxy += "/32" + } + _, cidrNet, err := net.ParseCIDR(trustedProxy) + if err != nil { + return cidr, err + } + cidr = append(cidr, cidrNet) + } + return cidr, nil + } + return nil, nil +} + // RunTLS attaches the router to a http.Server and starts listening and serving HTTPS (secure) requests. // It is a shortcut for http.ListenAndServeTLS(addr, certFile, keyFile, router) // Note: this method will block the calling goroutine indefinitely unless an error happens.