refactor: prepare trusted proxies

This commit is contained in:
Stefan Bildl 2023-03-25 15:56:45 +01:00
parent fc3aee84b5
commit 12d42118de
3 changed files with 38 additions and 32 deletions

View File

@ -1440,7 +1440,6 @@ func TestContextAbortWithError(t *testing.T) {
func TestContextClientIP(t *testing.T) { func TestContextClientIP(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("POST", "/", nil) c.Request, _ = http.NewRequest("POST", "/", nil)
c.engine.trustedCIDRs, _ = c.engine.prepareTrustedCIDRs()
resetContextForClientIPTests(c) resetContextForClientIPTests(c)
// Legacy tests (validating that the defaults don't break the // Legacy tests (validating that the defaults don't break the

66
gin.go
View File

@ -387,33 +387,22 @@ func (engine *Engine) Run(addr ...string) (err error) {
return return
} }
func (engine *Engine) prepareTrustedCIDRs() ([]*net.IPNet, error) { func prepareTrustedCIDRs(trustedProxies []string) ([]*net.IPNet, error) {
if engine.trustedProxies == nil { if trustedProxies == nil {
return nil, nil return nil, nil
} }
cidr := make([]*net.IPNet, 0, len(engine.trustedProxies)) cidrArr := make([]*net.IPNet, 0, len(trustedProxies))
for _, trustedProxy := range engine.trustedProxies {
if !strings.Contains(trustedProxy, "/") {
ip := parseIP(trustedProxy)
if ip == nil {
return cidr, &net.ParseError{Type: "IP address", Text: trustedProxy}
}
switch len(ip) { for _, trustedProxy := range trustedProxies {
case net.IPv4len: cidrNet, err := prepareCIDR(trustedProxy)
trustedProxy += "/32"
case net.IPv6len:
trustedProxy += "/128"
}
}
_, cidrNet, err := net.ParseCIDR(trustedProxy)
if err != nil { if err != nil {
return cidr, err return cidrArr, err
} }
cidr = append(cidr, cidrNet) cidrArr = append(cidrArr, cidrNet)
} }
return cidr, nil
return cidrArr, nil
} }
// SetTrustedProxies set a list of network origins (IPv4 addresses, // SetTrustedProxies set a list of network origins (IPv4 addresses,
@ -426,7 +415,35 @@ func (engine *Engine) prepareTrustedCIDRs() ([]*net.IPNet, error) {
// return the remote address directly. // return the remote address directly.
func (engine *Engine) SetTrustedProxies(trustedProxies []string) error { func (engine *Engine) SetTrustedProxies(trustedProxies []string) error {
engine.trustedProxies = trustedProxies engine.trustedProxies = trustedProxies
return engine.parseTrustedProxies() trustedCIDRs, err := prepareTrustedCIDRs(trustedProxies)
engine.trustedCIDRs = trustedCIDRs
return err
}
// converts a string CIDR or single IP to a IPNet instance
func prepareCIDR(ipOrCidr string) (*net.IPNet, error) {
// not a CIDR, try to convert to cidr notation
if !strings.Contains(ipOrCidr, "/") {
ip := parseIP(ipOrCidr)
if ip == nil {
return nil, &net.ParseError{Type: "IP address", Text: ipOrCidr}
}
switch len(ip) {
case net.IPv4len:
ipOrCidr += "/32"
case net.IPv6len:
ipOrCidr += "/128"
}
}
_, cidrNet, err := net.ParseCIDR(ipOrCidr)
if err != nil {
return nil, err
}
return cidrNet, err
} }
// isUnsafeTrustedProxies checks if Engine.trustedCIDRs contains all IPs, it's not safe if it has (returns true) // isUnsafeTrustedProxies checks if Engine.trustedCIDRs contains all IPs, it's not safe if it has (returns true)
@ -434,13 +451,6 @@ func (engine *Engine) isUnsafeTrustedProxies() bool {
return engine.isTrustedProxy(net.ParseIP("0.0.0.0")) || engine.isTrustedProxy(net.ParseIP("::")) return engine.isTrustedProxy(net.ParseIP("0.0.0.0")) || engine.isTrustedProxy(net.ParseIP("::"))
} }
// parseTrustedProxies parse Engine.trustedProxies to Engine.trustedCIDRs
func (engine *Engine) parseTrustedProxies() error {
trustedCIDRs, err := engine.prepareTrustedCIDRs()
engine.trustedCIDRs = trustedCIDRs
return err
}
// isTrustedProxy will check whether the IP address is included in the trusted list according to Engine.trustedCIDRs // isTrustedProxy will check whether the IP address is included in the trusted list according to Engine.trustedCIDRs
func (engine *Engine) isTrustedProxy(ip net.IP) bool { func (engine *Engine) isTrustedProxy(ip net.IP) bool {
if engine.trustedCIDRs == nil { if engine.trustedCIDRs == nil {

View File

@ -183,10 +183,7 @@ func TestLoggerWithConfigFormatting(t *testing.T) {
var gotParam LogFormatterParams var gotParam LogFormatterParams
var gotKeys map[string]any var gotKeys map[string]any
buffer := new(strings.Builder) buffer := new(strings.Builder)
router := New() router := New()
router.engine.trustedCIDRs, _ = router.engine.prepareTrustedCIDRs()
router.Use(LoggerWithConfig(LoggerConfig{ router.Use(LoggerWithConfig(LoggerConfig{
Output: buffer, Output: buffer,
Formatter: func(param LogFormatterParams) string { Formatter: func(param LogFormatterParams) string {