diff --git a/context.go b/context.go index 3e040761..037ea8fe 100644 --- a/context.go +++ b/context.go @@ -741,6 +741,7 @@ func (c *Context) ClientIP() string { if remoteIP == nil { return "" } + if trusted && c.engine.ForwardedByClientIP && c.engine.RemoteIPHeaders != nil { for _, headerName := range c.engine.RemoteIPHeaders { ip, valid := validateHeader(c.requestHeader(headerName)) @@ -765,13 +766,19 @@ func (c *Context) RemoteIP() (net.IP, bool) { if remoteIP == nil { return nil, false } - if c.engine.trustedCIDRs != nil { - for _, cidr := range c.engine.trustedCIDRs { - if cidr.Contains(remoteIP) { - return remoteIP, true + + trustedCIDRs, err := c.engine.prepareTrustedCIDRs() + if err == nil { + c.engine.trustedCIDRs = trustedCIDRs + if c.engine.trustedCIDRs != nil { + for _, cidr := range c.engine.trustedCIDRs { + if cidr.Contains(remoteIP) { + return remoteIP, true + } } } } + return remoteIP, false } diff --git a/context_test.go b/context_test.go index cb1f9c5b..9570a9e9 100644 --- a/context_test.go +++ b/context_test.go @@ -1430,7 +1430,7 @@ func TestContextClientIP(t *testing.T) { // Only trust RemoteAddr c.engine.TrustedProxies = []string{"40.40.40.40"} - assert.Equal(t, "30.30.30.30", c.ClientIP()) + assert.Equal(t, "20.20.20.20", c.ClientIP()) // All steps are trusted c.engine.TrustedProxies = []string{"40.40.40.40", "30.30.30.30", "20.20.20.20"} @@ -1442,7 +1442,7 @@ func TestContextClientIP(t *testing.T) { // Use hostname that resolves to all the proxies c.engine.TrustedProxies = []string{"foo"} - assert.Equal(t, "20.20.20.20", c.ClientIP()) + assert.Equal(t, "40.40.40.40", c.ClientIP()) // Use hostname that returns an error c.engine.TrustedProxies = []string{"bar"} diff --git a/gin.go b/gin.go index 2058c2d5..03a0e127 100644 --- a/gin.go +++ b/gin.go @@ -338,31 +338,32 @@ func (engine *Engine) Run(addr ...string) (err error) { } func (engine *Engine) prepareTrustedCIDRs() ([]*net.IPNet, error) { - if engine.TrustedProxies != nil { - cidr := make([]*net.IPNet, 0, len(engine.TrustedProxies)) - for _, trustedProxy := range engine.TrustedProxies { - if !strings.Contains(trustedProxy, "/") { - ip := parseIP(trustedProxy) - if ip == nil { - return cidr, &net.ParseError{Type: "IP address", Text: trustedProxy} - } - - switch len(ip) { - case net.IPv4len: - trustedProxy += "/32" - case net.IPv6len: - trustedProxy += "/128" - } - } - _, cidrNet, err := net.ParseCIDR(trustedProxy) - if err != nil { - return cidr, err - } - cidr = append(cidr, cidrNet) - } - return cidr, nil + if engine.TrustedProxies == nil { + return nil, nil } - return nil, nil + + cidr := make([]*net.IPNet, 0, len(engine.TrustedProxies)) + for _, trustedProxy := range engine.TrustedProxies { + if !strings.Contains(trustedProxy, "/") { + ip := parseIP(trustedProxy) + if ip == nil { + return cidr, &net.ParseError{Type: "IP address", Text: trustedProxy} + } + + switch len(ip) { + case net.IPv4len: + trustedProxy += "/32" + case net.IPv6len: + trustedProxy += "/128" + } + } + _, cidrNet, err := net.ParseCIDR(trustedProxy) + if err != nil { + return cidr, err + } + cidr = append(cidr, cidrNet) + } + return cidr, nil } // parseIP parse a string representation of an IP and returns a net.IP with the