refactor: prepare trusted proxies

This commit is contained in:
Stefan Bildl 2023-03-25 15:56:45 +01:00
parent fc3aee84b5
commit 12d42118de
3 changed files with 38 additions and 32 deletions

View File

@ -1440,7 +1440,6 @@ func TestContextAbortWithError(t *testing.T) {
func TestContextClientIP(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("POST", "/", nil)
c.engine.trustedCIDRs, _ = c.engine.prepareTrustedCIDRs()
resetContextForClientIPTests(c)
// Legacy tests (validating that the defaults don't break the

66
gin.go
View File

@ -387,33 +387,22 @@ func (engine *Engine) Run(addr ...string) (err error) {
return
}
func (engine *Engine) prepareTrustedCIDRs() ([]*net.IPNet, error) {
if engine.trustedProxies == nil {
func prepareTrustedCIDRs(trustedProxies []string) ([]*net.IPNet, error) {
if trustedProxies == nil {
return nil, nil
}
cidr := make([]*net.IPNet, 0, len(engine.trustedProxies))
for _, trustedProxy := range engine.trustedProxies {
if !strings.Contains(trustedProxy, "/") {
ip := parseIP(trustedProxy)
if ip == nil {
return cidr, &net.ParseError{Type: "IP address", Text: trustedProxy}
}
cidrArr := make([]*net.IPNet, 0, len(trustedProxies))
switch len(ip) {
case net.IPv4len:
trustedProxy += "/32"
case net.IPv6len:
trustedProxy += "/128"
}
}
_, cidrNet, err := net.ParseCIDR(trustedProxy)
for _, trustedProxy := range trustedProxies {
cidrNet, err := prepareCIDR(trustedProxy)
if err != nil {
return cidr, err
return cidrArr, err
}
cidr = append(cidr, cidrNet)
cidrArr = append(cidrArr, cidrNet)
}
return cidr, nil
return cidrArr, nil
}
// SetTrustedProxies set a list of network origins (IPv4 addresses,
@ -426,7 +415,35 @@ func (engine *Engine) prepareTrustedCIDRs() ([]*net.IPNet, error) {
// return the remote address directly.
func (engine *Engine) SetTrustedProxies(trustedProxies []string) error {
engine.trustedProxies = trustedProxies
return engine.parseTrustedProxies()
trustedCIDRs, err := prepareTrustedCIDRs(trustedProxies)
engine.trustedCIDRs = trustedCIDRs
return err
}
// converts a string CIDR or single IP to a IPNet instance
func prepareCIDR(ipOrCidr string) (*net.IPNet, error) {
// not a CIDR, try to convert to cidr notation
if !strings.Contains(ipOrCidr, "/") {
ip := parseIP(ipOrCidr)
if ip == nil {
return nil, &net.ParseError{Type: "IP address", Text: ipOrCidr}
}
switch len(ip) {
case net.IPv4len:
ipOrCidr += "/32"
case net.IPv6len:
ipOrCidr += "/128"
}
}
_, cidrNet, err := net.ParseCIDR(ipOrCidr)
if err != nil {
return nil, err
}
return cidrNet, err
}
// isUnsafeTrustedProxies checks if Engine.trustedCIDRs contains all IPs, it's not safe if it has (returns true)
@ -434,13 +451,6 @@ func (engine *Engine) isUnsafeTrustedProxies() bool {
return engine.isTrustedProxy(net.ParseIP("0.0.0.0")) || engine.isTrustedProxy(net.ParseIP("::"))
}
// parseTrustedProxies parse Engine.trustedProxies to Engine.trustedCIDRs
func (engine *Engine) parseTrustedProxies() error {
trustedCIDRs, err := engine.prepareTrustedCIDRs()
engine.trustedCIDRs = trustedCIDRs
return err
}
// isTrustedProxy will check whether the IP address is included in the trusted list according to Engine.trustedCIDRs
func (engine *Engine) isTrustedProxy(ip net.IP) bool {
if engine.trustedCIDRs == nil {

View File

@ -183,10 +183,7 @@ func TestLoggerWithConfigFormatting(t *testing.T) {
var gotParam LogFormatterParams
var gotKeys map[string]any
buffer := new(strings.Builder)
router := New()
router.engine.trustedCIDRs, _ = router.engine.prepareTrustedCIDRs()
router.Use(LoggerWithConfig(LoggerConfig{
Output: buffer,
Formatter: func(param LogFormatterParams) string {