diff --git a/context.go b/context.go index 5174033e..3850667d 100644 --- a/context.go +++ b/context.go @@ -1047,6 +1047,33 @@ func (c *Context) IsWebsocket() bool { return false } +// Scheme returns the HTTP scheme of the request ("http" or "https"). +// When running behind reverse proxies or load balancers `Request.URL.Scheme` is usually empty. +// the original scheme is commonly forwarded via headers such as X-Forwarded-Proto. +// Reference: +// https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/X-Forwarded-Proto +func (c *Context) Scheme() string { + if c.Request.TLS != nil { + return "https" + } + if scheme := c.requestHeader("X-Forwarded-Proto"); scheme != "" { + return scheme + } + if scheme := c.requestHeader("X-Forwarded-Protocol"); scheme != "" { + return scheme + } + if ssl := c.requestHeader("X-Forwarded-Ssl"); ssl == "on" { + return "https" + } + if scheme := c.requestHeader("X-Url-Scheme"); scheme != "" { + return scheme + } + if scheme := c.Request.URL.Scheme; scheme != "" { + return scheme + } + return "http" +} + func (c *Context) requestHeader(key string) string { return c.Request.Header.Get(key) } diff --git a/context_test.go b/context_test.go index ef60379d..364a92ae 100644 --- a/context_test.go +++ b/context_test.go @@ -7,6 +7,7 @@ package gin import ( "bytes" "context" + "crypto/tls" "errors" "fmt" "html/template" @@ -2955,6 +2956,65 @@ func TestWebsocketsRequired(t *testing.T) { assert.False(t, c.IsWebsocket()) } +func TestContextScheme(t *testing.T) { + // TLS connection takes highest priority. + c, _ := CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) + c.Request.TLS = &tls.ConnectionState{} + assert.Equal(t, "https", c.Scheme()) + + // X-Forwarded-Proto header. + c, _ = CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("X-Forwarded-Proto", "https") + assert.Equal(t, "https", c.Scheme()) + + c, _ = CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("X-Forwarded-Proto", "http") + assert.Equal(t, "http", c.Scheme()) + + // X-Forwarded-Protocol header. + c, _ = CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("X-Forwarded-Protocol", "https") + assert.Equal(t, "https", c.Scheme()) + + // X-Forwarded-Ssl: on header. + c, _ = CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("X-Forwarded-Ssl", "on") + assert.Equal(t, "https", c.Scheme()) + + c, _ = CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("X-Forwarded-Ssl", "off") + assert.Equal(t, "http", c.Scheme()) + + // X-Url-Scheme header. + c, _ = CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("X-Url-Scheme", "https") + assert.Equal(t, "https", c.Scheme()) + + // Request.URL.Scheme fallback. + c, _ = CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest(http.MethodGet, "https://example.com/", nil) + assert.Equal(t, "https", c.Scheme()) + + // Default fallback: plain http. + c, _ = CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) + assert.Equal(t, "http", c.Scheme()) + + // TLS takes priority over X-Forwarded-Proto. + c, _ = CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) + c.Request.TLS = &tls.ConnectionState{} + c.Request.Header.Set("X-Forwarded-Proto", "http") + assert.Equal(t, "https", c.Scheme()) +} + func TestGetRequestHeaderValue(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) c.Request, _ = http.NewRequest(http.MethodGet, "/chat", nil)