diff --git a/binding/binding_test.go b/binding/binding_test.go index a9f8b9e3..cdf1c23c 100644 --- a/binding/binding_test.go +++ b/binding/binding_test.go @@ -1403,6 +1403,23 @@ func TestPlainBinding(t *testing.T) { require.NoError(t, p.Bind(req, ptr)) } +func TestPlainBindingBindBody(t *testing.T) { + p := Plain + + var s string + require.NoError(t, p.BindBody([]byte("test body"), &s)) + assert.Equal(t, "test body", s) + + var bs []byte + require.NoError(t, p.BindBody([]byte("test bytes"), &bs)) + assert.Equal(t, []byte("test bytes"), bs) + + var i int + require.Error(t, p.BindBody([]byte("test"), &i)) + + require.NoError(t, p.BindBody([]byte("test"), nil)) +} + func testProtoBodyBindingFail(t *testing.T, b Binding, name, path, badPath, body, badBody string) { assert.Equal(t, name, b.Name()) diff --git a/codec/json/json_test.go b/codec/json/json_test.go new file mode 100644 index 00000000..e21dbbb7 --- /dev/null +++ b/codec/json/json_test.go @@ -0,0 +1,53 @@ +// Copyright 2025 Gin Core Team. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +package json + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestJSONMarshal(t *testing.T) { + data := map[string]string{"key": "value"} + result, err := API.Marshal(data) + require.NoError(t, err) + assert.JSONEq(t, `{"key":"value"}`, string(result)) +} + +func TestJSONUnmarshal(t *testing.T) { + var data map[string]string + err := API.Unmarshal([]byte(`{"key":"value"}`), &data) + require.NoError(t, err) + assert.Equal(t, "value", data["key"]) +} + +func TestJSONMarshalIndent(t *testing.T) { + data := map[string]string{"key": "value"} + result, err := API.MarshalIndent(data, "", " ") + require.NoError(t, err) + assert.Contains(t, string(result), `"key": "value"`) +} + +func TestJSONNewEncoder(t *testing.T) { + var buf bytes.Buffer + encoder := API.NewEncoder(&buf) + require.NotNil(t, encoder) + err := encoder.Encode(map[string]string{"key": "value"}) + require.NoError(t, err) + assert.JSONEq(t, `{"key":"value"}`, buf.String()) +} + +func TestJSONNewDecoder(t *testing.T) { + buf := bytes.NewBufferString(`{"key":"value"}`) + decoder := API.NewDecoder(buf) + require.NotNil(t, decoder) + var data map[string]string + err := decoder.Decode(&data) + require.NoError(t, err) + assert.Equal(t, "value", data["key"]) +} diff --git a/context_test.go b/context_test.go index 41694585..5b53a6c9 100644 --- a/context_test.go +++ b/context_test.go @@ -2947,6 +2947,17 @@ func TestContextGetRawData(t *testing.T) { assert.Equal(t, "Fetch binary post data", string(data)) } +func TestContextGetRawDataNilBody(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest(http.MethodPost, "/", nil) + c.Request.Body = nil + + data, err := c.GetRawData() + require.Error(t, err) + assert.Nil(t, data) + assert.Equal(t, "cannot read nil body", err.Error()) +} + func TestContextRenderDataFromReader(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) @@ -3535,6 +3546,24 @@ func TestContextSetCookieData(t *testing.T) { setCookie := c.Writer.Header().Get("Set-Cookie") assert.Contains(t, setCookie, "SameSite=None") }) + + // Test that SameSiteDefaultMode is replaced with context's SameSite + t.Run("SameSiteDefaultMode is replaced with context SameSite", func(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + c.SetSameSite(http.SameSiteLaxMode) + cookie := &http.Cookie{ + Name: "user", + Value: "gin", + Path: "/", + Domain: "localhost", + Secure: true, + HttpOnly: true, + SameSite: http.SameSiteDefaultMode, + } + c.SetCookieData(cookie) + setCookie := c.Writer.Header().Get("Set-Cookie") + assert.Contains(t, setCookie, "user=gin") + }) } func TestGetMapFromFormData(t *testing.T) { diff --git a/gin.go b/gin.go index 2e033bf3..bd8a824c 100644 --- a/gin.go +++ b/gin.go @@ -96,6 +96,9 @@ type Engine struct { // (used for routing HTTP requests) happens only once, even if called multiple times concurrently. routeTreesUpdated sync.Once + // mu protects the route trees during concurrent access. + mu sync.RWMutex + // RedirectTrailingSlash enables automatic redirection if the current route can't be matched but a // handler for the path with (without) the trailing slash exists. // For example if /foo/ is requested but a route only exists for /foo, the @@ -368,6 +371,9 @@ func (engine *Engine) addRoute(method, path string, handlers HandlersChain) { debugPrintRoute(method, path, handlers) + engine.mu.Lock() + defer engine.mu.Unlock() + root := engine.trees.get(method) if root == nil { root = new(node) @@ -388,6 +394,8 @@ func (engine *Engine) addRoute(method, path string, handlers HandlersChain) { // Routes returns a slice of registered routes, including some useful information, such as: // the http method, path, and the handler name. func (engine *Engine) Routes() (routes RoutesInfo) { + engine.mu.RLock() + defer engine.mu.RUnlock() for _, tree := range engine.trees { routes = iterate("", tree.method, routes, tree.root) } diff --git a/ginS/gins_test.go b/ginS/gins_test.go index ffde85d2..1821e780 100644 --- a/ginS/gins_test.go +++ b/ginS/gins_test.go @@ -244,3 +244,30 @@ func TestStaticFS(t *testing.T) { assert.Equal(t, http.StatusOK, w.Code) } + +func TestLoadHTMLGlob(t *testing.T) { + LoadHTMLGlob("../testdata/template/*.tmpl") +} + +func TestLoadHTMLFiles(t *testing.T) { + LoadHTMLFiles("../testdata/template/hello.tmpl") +} + +func TestLoadHTMLFS(t *testing.T) { + LoadHTMLFS(http.Dir("../testdata/template"), "hello.tmpl") +} + +func TestRunInvalidAddress(t *testing.T) { + err := Run("invalid:address:format") + assert.Error(t, err) +} + +func TestRunTLSInvalid(t *testing.T) { + err := RunTLS("invalid:address:format", "nonexistent.crt", "nonexistent.key") + assert.Error(t, err) +} + +func TestRunUnixInvalid(t *testing.T) { + err := RunUnix("/nonexistent/path/socket.sock") + assert.Error(t, err) +} diff --git a/gin_test.go b/gin_test.go index 43c9494d..27639858 100644 --- a/gin_test.go +++ b/gin_test.go @@ -15,6 +15,7 @@ import ( "reflect" "strconv" "strings" + "sync" "sync/atomic" "testing" "time" @@ -1084,3 +1085,69 @@ func TestUpdateRouteTreesCalledOnce(t *testing.T) { assert.Equal(t, "ok", w.Body.String()) } } + +func TestConcurrentAddRouteAndRoutes(t *testing.T) { + router := New() + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(2) + go func(i int) { + defer wg.Done() + router.GET(fmt.Sprintf("/route%d", i), func(c *Context) { + c.String(http.StatusOK, "ok") + }) + }(i) + go func(i int) { + defer wg.Done() + _ = router.Routes() + }(i) + } + wg.Wait() + assert.Len(t, router.Routes(), 100) +} + +func TestServeErrorWritten(t *testing.T) { + router := New() + router.GET("/", func(c *Context) { + c.Status(http.StatusInternalServerError) + _, _ = c.Writer.Write([]byte("custom error")) + }) + router.NoRoute(func(c *Context) { + c.Status(http.StatusNotFound) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/notfound", nil) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusNotFound, w.Code) +} + +func TestServeErrorStatusMismatch(t *testing.T) { + router := New() + router.HandleMethodNotAllowed = true + router.NoMethod(func(c *Context) { + c.Status(http.StatusForbidden) + }) + + router.GET("/exists", func(c *Context) { + c.String(http.StatusOK, "ok") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/exists", nil) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusForbidden, w.Code) +} + +func TestServeErrorMessageWrite(t *testing.T) { + router := New() + router.NoRoute(func(c *Context) { + c.Next() + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/notfound", nil) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusNotFound, w.Code) + assert.Equal(t, "text/plain", w.Header().Get("Content-Type")) +} diff --git a/logger_test.go b/logger_test.go index 53d0df95..2d22c6e1 100644 --- a/logger_test.go +++ b/logger_test.go @@ -329,6 +329,7 @@ func TestColorForLatency(t *testing.T) { assert.Equal(t, white, colorForLantency(time.Millisecond*20), "20ms should be white") assert.Equal(t, green, colorForLantency(time.Millisecond*150), "150ms should be green") assert.Equal(t, cyan, colorForLantency(time.Millisecond*250), "250ms should be cyan") + assert.Equal(t, blue, colorForLantency(time.Millisecond*400), "400ms should be blue") assert.Equal(t, yellow, colorForLantency(time.Millisecond*600), "600ms should be yellow") assert.Equal(t, magenta, colorForLantency(time.Millisecond*1500), "1.5s should be magenta") assert.Equal(t, red, colorForLantency(time.Second*3), "other things should be red") diff --git a/utils_test.go b/utils_test.go index 893ebc88..ee91bddd 100644 --- a/utils_test.go +++ b/utils_test.go @@ -13,6 +13,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func init() { @@ -145,6 +146,17 @@ func TestMarshalXMLforH(t *testing.T) { assert.Error(t, e) } +func TestMarshalXMLforHSuccess(t *testing.T) { + h := H{ + "key": "value", + "number": 42, + } + data, err := xml.Marshal(h) + require.NoError(t, err) + assert.Contains(t, string(data), "value") + assert.Contains(t, string(data), "42") +} + func TestIsASCII(t *testing.T) { assert.True(t, isASCII("test")) assert.False(t, isASCII("๐Ÿงก๐Ÿ’›๐Ÿ’š๐Ÿ’™๐Ÿ’œ"))