diff --git a/gin.go b/gin.go index 1633fe13..18afdb51 100644 --- a/gin.go +++ b/gin.go @@ -17,6 +17,7 @@ import ( "github.com/gin-gonic/gin/internal/bytesconv" "github.com/gin-gonic/gin/render" + "github.com/gin-gonic/gin/utils" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" ) @@ -401,24 +402,17 @@ func (engine *Engine) Run(addr ...string) (err error) { } func (engine *Engine) prepareTrustedCIDRs() ([]*net.IPNet, error) { + var err error if engine.trustedProxies == nil { return nil, nil } cidr := make([]*net.IPNet, 0, len(engine.trustedProxies)) for _, trustedProxy := range engine.trustedProxies { - if !strings.Contains(trustedProxy, "/") { - ip := parseIP(trustedProxy) - if ip == nil { - return cidr, &net.ParseError{Type: "IP address", Text: trustedProxy} - } - - switch len(ip) { - case net.IPv4len: - trustedProxy += "/32" - case net.IPv6len: - trustedProxy += "/128" - } + trustedProxy, err = utils.MakeTrustIP(trustedProxy) + + if err != nil { + return cidr, err } _, cidrNet, err := net.ParseCIDR(trustedProxy) if err != nil { @@ -489,20 +483,6 @@ func (engine *Engine) validateHeader(header string) (clientIP string, valid bool return "", false } -// parseIP parse a string representation of an IP and returns a net.IP with the -// minimum byte representation or nil if input is invalid. -func parseIP(ip string) net.IP { - parsedIP := net.ParseIP(ip) - - if ipv4 := parsedIP.To4(); ipv4 != nil { - // return ip in a 4-byte representation - return ipv4 - } - - // return ip in a 16-byte representation or nil - return parsedIP -} - // RunTLS attaches the router to a http.Server and starts listening and serving HTTPS (secure) requests. // It is a shortcut for http.ListenAndServeTLS(addr, certFile, keyFile, router) // Note: this method will block the calling goroutine indefinitely unless an error happens. diff --git a/utils/ip.go b/utils/ip.go new file mode 100644 index 00000000..9f4d54b9 --- /dev/null +++ b/utils/ip.go @@ -0,0 +1,46 @@ +package utils + +import ( + "net" + "strings" +) + +func parseIP(ip string) (net.IP, error) { + parsedIP := net.ParseIP(ip) + + if ipv4 := parsedIP.To4(); ipv4 != nil { + return ipv4, nil + } + if parsedIP != nil{ + return parsedIP, nil + } + return nil, &net.ParseError{Type: "IP address", Text: ip} +} + +func MakeTrustIP(trustedIP string) (string, error) { + if strings.Contains(trustedIP, "/") { + return trustedIP, nil + } + ip, err := parseIP(trustedIP) + + if err != nil { + return "", err + } + + var mapRenderIP = map [int]func(trustIP string) string{ + net.IPv4len: func(trustIP string) string{ + return trustIP + "/32" + }, + net.IPv6len: func(trustIP string) string{ + return trustIP + "/32" + }, + } + + fn, isExistKey := mapRenderIP[len(ip)] + + if isExistKey != true{ + return "", &net.ParseError{Type: "IP address", Text: trustedIP} + } + + return fn(trustedIP), nil +} \ No newline at end of file