From c2268830431c0d13fb2c0d8ec46de6f375e31e1f Mon Sep 17 00:00:00 2001 From: Yuki Igarashi Date: Tue, 7 Apr 2020 23:39:29 +0900 Subject: [PATCH] Handle invalid Accept header --- context.go | 28 ++++++++++++++++++---------- context_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/context.go b/context.go index a2384c0e..30b7891b 100644 --- a/context.go +++ b/context.go @@ -1047,19 +1047,27 @@ func (c *Context) NegotiateFormat(offered ...string) string { return offered[0] } for _, accepted := range c.Accepted { + // According to RFC 2616, media-range = ( "*/*" | ( type "/" "*" ) | ( type "/" subtype ) ) + acceptedMediaRange := strings.Split(accepted, "/") + if len(acceptedMediaRange) != 2 || (acceptedMediaRange[0] == "*" && acceptedMediaRange[1] != "*") { + continue + } + for _, offer := range offered { // According to RFC 2616 and RFC 2396, non-ASCII characters are not allowed in headers, - // therefore we can just iterate over the string without casting it into []rune - i := 0 - for ; i < len(accepted); i++ { - if accepted[i] == '*' || offer[i] == '*' { - return offer - } - if accepted[i] != offer[i] { - break - } + // therefore we can handle the string without casting it into []rune + + offerMediaRange := strings.Split(offer, "/") + if len(offerMediaRange) != 2 { + continue } - if i == len(accepted) { + + if acceptedMediaRange[0] == "*" || offerMediaRange[0] == "*" { + return offer + } + + if acceptedMediaRange[0] == offerMediaRange[0] && + (acceptedMediaRange[1] == "*" || offerMediaRange[1] == "*" || acceptedMediaRange[1] == offerMediaRange[1]) { return offer } } diff --git a/context_test.go b/context_test.go index ce077bc6..f46091fa 100644 --- a/context_test.go +++ b/context_test.go @@ -1228,6 +1228,32 @@ func TestContextNegotiationFormatWithWildcardAccept(t *testing.T) { assert.Equal(t, c.NegotiateFormat(MIMEHTML), MIMEHTML) } +func TestContextNegotiationFormatWithInvalidAccept(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest("POST", "/", nil) + c.Request.Header.Add("Accept", "text*,*/html") + + assert.Equal(t, c.NegotiateFormat("*/*"), "") + assert.Equal(t, c.NegotiateFormat("text/*"), "") + assert.Equal(t, c.NegotiateFormat("text/html"), "") + assert.Equal(t, c.NegotiateFormat(MIMEJSON), "") + assert.Equal(t, c.NegotiateFormat(MIMEXML), "") + assert.Equal(t, c.NegotiateFormat(MIMEHTML), "") + + c, _ = CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest("POST", "/", nil) + c.Request.Header.Add("Accept", "text*/*,application/j*,application/jso,application/json2") + + assert.Equal(t, c.NegotiateFormat("*/*"), "*/*") + assert.Equal(t, c.NegotiateFormat("text/*"), "") + assert.Equal(t, c.NegotiateFormat("text/html"), "") + assert.Equal(t, c.NegotiateFormat("application/*"), "application/*") + assert.Equal(t, c.NegotiateFormat("application/json"), "") + assert.Equal(t, c.NegotiateFormat(MIMEJSON), "") + assert.Equal(t, c.NegotiateFormat(MIMEXML), "") + assert.Equal(t, c.NegotiateFormat(MIMEHTML), "") +} + func TestContextNegotiationFormatCustom(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) c.Request, _ = http.NewRequest("POST", "/", nil)