Compare commits

...

3 Commits

Author SHA1 Message Date
Milad
ffbac3b065
Merge 38c24307a96a49e25391b9fe12d6c803ac2c0a82 into 9914178584e42458ff7d23891463a880f58c9d86 2026-01-02 16:18:45 +08:00
Nurysso
9914178584
fix(context): ClientIP handling for multiple X-Forwarded-For header values (#4472)
* Fix ClientIP calculation by concatenating all RemoteIPHeaders values

* test: used http.MethodGet instead constants and fix lints

* lint error fixed

* Refactor ClientIP X-Forwarded-For tests

---------

Co-authored-by: Bo-Yi Wu <appleboy.tw@gmail.com>
2026-01-02 10:15:27 +08:00
Miladev95
38c24307a9 Add ensureParamsCapacity helper and doc comment in tree.go 2025-12-08 13:23:55 +03:30
3 changed files with 50 additions and 11 deletions

View File

@ -989,7 +989,8 @@ func (c *Context) ClientIP() string {
if trusted && c.engine.ForwardedByClientIP && c.engine.RemoteIPHeaders != nil {
for _, headerName := range c.engine.RemoteIPHeaders {
ip, valid := c.engine.validateHeader(c.requestHeader(headerName))
headerValue := strings.Join(c.Request.Header.Values(headerName), ",")
ip, valid := c.engine.validateHeader(headerValue)
if valid {
return ip
}

View File

@ -1143,6 +1143,37 @@ func TestContextRenderNoContentIndentedJSON(t *testing.T) {
assert.Equal(t, "application/json; charset=utf-8", w.Header().Get("Content-Type"))
}
func TestContextClientIPWithMultipleHeaders(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest(http.MethodGet, "/test", nil)
// Multiple X-Forwarded-For headers
c.Request.Header.Add("X-Forwarded-For", "1.2.3.4, "+localhostIP)
c.Request.Header.Add("X-Forwarded-For", "5.6.7.8")
c.Request.RemoteAddr = localhostIP + ":1234"
c.engine.ForwardedByClientIP = true
c.engine.RemoteIPHeaders = []string{"X-Forwarded-For"}
_ = c.engine.SetTrustedProxies([]string{localhostIP})
// Should return 5.6.7.8 (last non-trusted IP)
assert.Equal(t, "5.6.7.8", c.ClientIP())
}
func TestContextClientIPWithSingleHeader(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest(http.MethodGet, "/test", nil)
c.Request.Header.Set("X-Forwarded-For", "1.2.3.4, "+localhostIP)
c.Request.RemoteAddr = localhostIP + ":1234"
c.engine.ForwardedByClientIP = true
c.engine.RemoteIPHeaders = []string{"X-Forwarded-For"}
_ = c.engine.SetTrustedProxies([]string{localhostIP})
// Should return 1.2.3.4
assert.Equal(t, "1.2.3.4", c.ClientIP())
}
// Tests that the response is serialized as Secure JSON
// and Content-Type is set to application/json
func TestContextRenderSecureJSON(t *testing.T) {

27
tree.go
View File

@ -58,6 +58,8 @@ func (trees methodTrees) get(method string) *node {
return nil
}
// longestCommonPrefix returns the length in bytes of the longest common prefix
// of the two input strings `a` and `b`.
func longestCommonPrefix(a, b string) int {
i := 0
max_ := min(len(a), len(b))
@ -410,6 +412,19 @@ type skippedNode struct {
paramsCount int16
}
// ensureParamsCapacity ensures that the params slice has capacity for at least needed elements.
// It preserves existing length and content.
func ensureParamsCapacity(params *Params, needed int) {
if params == nil {
return
}
if cap(*params) < needed {
newParams := make(Params, len(*params), needed)
copy(newParams, *params)
*params = newParams
}
}
// Returns the handle registered with the given path (key). The values of
// wildcards are saved to a map.
// If no handle can be found, a TSR (trailing slash redirect) recommendation is
@ -497,11 +512,7 @@ walk: // Outer loop for walking the tree
// Save param value
if params != nil {
// Preallocate capacity if necessary
if cap(*params) < int(globalParamsCount) {
newParams := make(Params, len(*params), globalParamsCount)
copy(newParams, *params)
*params = newParams
}
ensureParamsCapacity(params, int(globalParamsCount))
if value.params == nil {
value.params = params
@ -550,11 +561,7 @@ walk: // Outer loop for walking the tree
// Save param value
if params != nil {
// Preallocate capacity if necessary
if cap(*params) < int(globalParamsCount) {
newParams := make(Params, len(*params), globalParamsCount)
copy(newParams, *params)
*params = newParams
}
ensureParamsCapacity(params, int(globalParamsCount))
if value.params == nil {
value.params = params