diff --git a/gin.go b/gin.go index c901bf96..2252bb9e 100644 --- a/gin.go +++ b/gin.go @@ -13,7 +13,6 @@ import ( "path" "strings" "sync" - "unicode" "github.com/gin-gonic/gin/internal/bytesconv" filesystem "github.com/gin-gonic/gin/internal/fs" @@ -747,7 +746,7 @@ func redirectTrailingSlash(c *Context) { p := req.URL.Path if prefix := path.Clean(c.Request.Header.Get("X-Forwarded-Prefix")); prefix != "." { prefix = sanitizePathChars(prefix) - prefix = removeRepeatedSlash(prefix) + prefix = removeRepeatedChar(prefix, '/') p = prefix + "/" + req.URL.Path } @@ -759,38 +758,16 @@ func redirectTrailingSlash(c *Context) { } // sanitizePathChars removes unsafe characters from path strings, -// keeping only letters, numbers, forward slashes, and hyphens. +// keeping only ASCII letters, ASCII numbers, forward slashes, and hyphens. func sanitizePathChars(s string) string { return strings.Map(func(r rune) rune { - if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '/' || r == '-' { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '/' || r == '-' { return r } return -1 }, s) } -// removeRepeatedSlash removes consecutive forward slashes from a string, -// replacing sequences of multiple slashes with a single slash. -func removeRepeatedSlash(s string) string { - if !strings.Contains(s, "//") { - return s - } - - var sb strings.Builder - sb.Grow(len(s) - 1) - prevChar := rune(0) - - for _, r := range s { - if r == '/' && prevChar == '/' { - continue - } - sb.WriteRune(r) - prevChar = r - } - - return sb.String() -} - func redirectFixedPath(c *Context, root *node, trailingSlash bool) bool { req := c.Request rPath := req.URL.Path diff --git a/gin_test.go b/gin_test.go index 245fc408..be076537 100644 --- a/gin_test.go +++ b/gin_test.go @@ -913,34 +913,3 @@ func TestMethodNotAllowedNoRoute(t *testing.T) { assert.NotPanics(t, func() { g.ServeHTTP(resp, req) }) assert.Equal(t, http.StatusNotFound, resp.Code) } - -func TestRemoveRepeatedSlash(t *testing.T) { - testCases := []struct { - name string - str string - want string - }{ - { - name: "noSlash", - str: "abc", - want: "abc", - }, - { - name: "withSlash", - str: "/a/b/c/", - want: "/a/b/c/", - }, - { - name: "withRepeatedSlash", - str: "/a//b///c////", - want: "/a/b/c/", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - res := removeRepeatedSlash(tc.str) - assert.Equal(t, tc.want, res) - }) - } -} diff --git a/path.go b/path.go index 82438c13..3b67caa9 100644 --- a/path.go +++ b/path.go @@ -5,6 +5,8 @@ package gin +const stackBufSize = 128 + // cleanPath is the URL version of path.Clean, it returns a canonical URL path // for p, eliminating . and .. elements. // @@ -19,7 +21,6 @@ package gin // // If the result of this process is an empty string, "/" is returned. func cleanPath(p string) string { - const stackBufSize = 128 // Turn empty string into "/" if p == "" { return "/" @@ -148,3 +149,55 @@ func bufApp(buf *[]byte, s string, w int, c byte) { } b[w] = c } + +// removeRepeatedChar removes multiple consecutive 'char's from a string. +// if s == "/a//b///c////" && char == '/', it returns "/a/b/c/" +func removeRepeatedChar(s string, char byte) string { + // Check if there are any consecutive chars + hasRepeatedChar := false + for i := 1; i < len(s); i++ { + if s[i] == char && s[i-1] == char { + hasRepeatedChar = true + break + } + } + if !hasRepeatedChar { + return s + } + + // Reasonably sized buffer on stack to avoid allocations in the common case. + buf := make([]byte, 0, stackBufSize) + + // Invariants: + // reading from s; r is index of next byte to process. + // writing to buf; w is index of next byte to write. + r := 0 + w := 0 + + for n := len(s); r < n; { + if s[r] == char { + // Write the first char + bufApp(&buf, s, w, char) + w++ + r++ + + // Skip all consecutive chars + for r < n && s[r] == char { + r++ + } + } else { + // Copy non-char character + bufApp(&buf, s, w, s[r]) + w++ + r++ + } + } + + // If the original string was not modified (or only shortened at the end), + // return the respective substring of the original string. + // Otherwise, return a new string from the buffer. + if len(buf) == 0 { + return s[:w] + } + return string(buf[:w]) +} diff --git a/path_test.go b/path_test.go index 2269b78e..cc8a6f04 100644 --- a/path_test.go +++ b/path_test.go @@ -143,3 +143,50 @@ func BenchmarkPathCleanLong(b *testing.B) { } } } + +func TestRemoveRepeatedChar(t *testing.T) { + testCases := []struct { + name string + str string + char byte + want string + }{ + { + name: "empty", + str: "", + char: 'a', + want: "", + }, + { + name: "noSlash", + str: "abc", + char: ',', + want: "abc", + }, + { + name: "withSlash", + str: "/a/b/c/", + char: '/', + want: "/a/b/c/", + }, + { + name: "withRepeatedSlashes", + str: "/a//b///c////", + char: '/', + want: "/a/b/c/", + }, + { + name: "threeSlashes", + str: "///", + char: '/', + want: "/", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + res := removeRepeatedChar(tc.str, tc.char) + assert.Equal(t, tc.want, res) + }) + } +}