diff --git a/context_test.go b/context_test.go index 1dec902c..a7d741ea 100644 --- a/context_test.go +++ b/context_test.go @@ -1440,7 +1440,6 @@ func TestContextAbortWithError(t *testing.T) { func TestContextClientIP(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) c.Request, _ = http.NewRequest("POST", "/", nil) - c.engine.trustedCIDRs, _ = c.engine.prepareTrustedCIDRs() resetContextForClientIPTests(c) // Legacy tests (validating that the defaults don't break the diff --git a/gin.go b/gin.go index 3cb9aced..6362bf24 100644 --- a/gin.go +++ b/gin.go @@ -387,33 +387,22 @@ func (engine *Engine) Run(addr ...string) (err error) { return } -func (engine *Engine) prepareTrustedCIDRs() ([]*net.IPNet, error) { - if engine.trustedProxies == nil { +func prepareTrustedCIDRs(trustedProxies []string) ([]*net.IPNet, error) { + if trustedProxies == nil { return nil, nil } - cidr := make([]*net.IPNet, 0, len(engine.trustedProxies)) - for _, trustedProxy := range engine.trustedProxies { - if !strings.Contains(trustedProxy, "/") { - ip := parseIP(trustedProxy) - if ip == nil { - return cidr, &net.ParseError{Type: "IP address", Text: trustedProxy} - } + cidrArr := make([]*net.IPNet, 0, len(trustedProxies)) - switch len(ip) { - case net.IPv4len: - trustedProxy += "/32" - case net.IPv6len: - trustedProxy += "/128" - } - } - _, cidrNet, err := net.ParseCIDR(trustedProxy) + for _, trustedProxy := range trustedProxies { + cidrNet, err := prepareCIDR(trustedProxy) if err != nil { - return cidr, err + return cidrArr, err } - cidr = append(cidr, cidrNet) + cidrArr = append(cidrArr, cidrNet) } - return cidr, nil + + return cidrArr, nil } // SetTrustedProxies set a list of network origins (IPv4 addresses, @@ -426,7 +415,35 @@ func (engine *Engine) prepareTrustedCIDRs() ([]*net.IPNet, error) { // return the remote address directly. func (engine *Engine) SetTrustedProxies(trustedProxies []string) error { engine.trustedProxies = trustedProxies - return engine.parseTrustedProxies() + trustedCIDRs, err := prepareTrustedCIDRs(trustedProxies) + engine.trustedCIDRs = trustedCIDRs + return err +} + +// converts a string CIDR or single IP to a IPNet instance +func prepareCIDR(ipOrCidr string) (*net.IPNet, error) { + // not a CIDR, try to convert to cidr notation + if !strings.Contains(ipOrCidr, "/") { + ip := parseIP(ipOrCidr) + if ip == nil { + return nil, &net.ParseError{Type: "IP address", Text: ipOrCidr} + } + + switch len(ip) { + case net.IPv4len: + ipOrCidr += "/32" + case net.IPv6len: + ipOrCidr += "/128" + } + } + + _, cidrNet, err := net.ParseCIDR(ipOrCidr) + + if err != nil { + return nil, err + } + + return cidrNet, err } // isUnsafeTrustedProxies checks if Engine.trustedCIDRs contains all IPs, it's not safe if it has (returns true) @@ -434,13 +451,6 @@ func (engine *Engine) isUnsafeTrustedProxies() bool { return engine.isTrustedProxy(net.ParseIP("0.0.0.0")) || engine.isTrustedProxy(net.ParseIP("::")) } -// parseTrustedProxies parse Engine.trustedProxies to Engine.trustedCIDRs -func (engine *Engine) parseTrustedProxies() error { - trustedCIDRs, err := engine.prepareTrustedCIDRs() - engine.trustedCIDRs = trustedCIDRs - return err -} - // isTrustedProxy will check whether the IP address is included in the trusted list according to Engine.trustedCIDRs func (engine *Engine) isTrustedProxy(ip net.IP) bool { if engine.trustedCIDRs == nil { diff --git a/logger_test.go b/logger_test.go index 5f78708f..186a305a 100644 --- a/logger_test.go +++ b/logger_test.go @@ -183,10 +183,7 @@ func TestLoggerWithConfigFormatting(t *testing.T) { var gotParam LogFormatterParams var gotKeys map[string]any buffer := new(strings.Builder) - router := New() - router.engine.trustedCIDRs, _ = router.engine.prepareTrustedCIDRs() - router.Use(LoggerWithConfig(LoggerConfig{ Output: buffer, Formatter: func(param LogFormatterParams) string {