refactor move logic to remoteIP()

This commit is contained in:
Manu Mtz.-Almeida 2021-02-08 14:08:35 +01:00
parent e14a43cc4c
commit 55ad88a12b

View File

@ -729,42 +729,50 @@ func (c *Context) ShouldBindBodyWith(obj interface{}, bb binding.BindingBody) (e
// X-Real-IP and X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy. // X-Real-IP and X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy.
// Use X-Forwarded-For before X-Real-Ip as nginx uses X-Real-Ip with the proxy's IP. // Use X-Forwarded-For before X-Real-Ip as nginx uses X-Real-Ip with the proxy's IP.
func (c *Context) ClientIP() string { func (c *Context) ClientIP() string {
ip, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr))
if err != nil {
return ""
}
remoteIP := net.ParseIP(ip)
if remoteIP == nil {
return ""
}
if c.engine.AppEngine { if c.engine.AppEngine {
if addr := c.requestHeader("X-Appengine-Remote-Addr"); addr != "" { if addr := c.requestHeader("X-Appengine-Remote-Addr"); addr != "" {
return addr return addr
} }
} }
if c.shouldCheckIPHeaders() { remoteIP, trusted := c.RemoteIP()
for _, cidr := range c.engine.trustedCIDRs { if remoteIP == nil {
if cidr.Contains(remoteIP) { return ""
for _, headerName := range c.engine.RemoteIPHeaders { }
ip, valid := validateHeader(c.requestHeader(headerName)) if trusted {
if valid { for _, headerName := range c.engine.RemoteIPHeaders {
return ip ip, valid := validateHeader(c.requestHeader(headerName))
} if valid {
} return ip
} }
} }
} }
return remoteIP.String() return remoteIP.String()
} }
func (c *Context) shouldCheckIPHeaders() bool { func (c *Context) RemoteIP() (net.IP, bool) {
return c.engine.ForwardedByClientIP && ip, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr))
if err != nil {
return nil, false
}
remoteIP := net.ParseIP(ip)
if remoteIP == nil {
return nil, false
}
shouldCheckTrustedIP := c.engine.ForwardedByClientIP &&
c.engine.RemoteIPHeaders != nil && c.engine.RemoteIPHeaders != nil &&
len(c.engine.RemoteIPHeaders) > 0 && len(c.engine.RemoteIPHeaders) > 0 &&
c.engine.trustedCIDRs != nil c.engine.trustedCIDRs != nil
if shouldCheckTrustedIP {
for _, cidr := range c.engine.trustedCIDRs {
if cidr.Contains(remoteIP) {
return remoteIP, true
}
}
}
return remoteIP, false
} }
func validateHeader(header string) (clientIP string, valid bool) { func validateHeader(header string) (clientIP string, valid bool) {