From c22d088aadf69ff49e52104dbffa47a13c905f99 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Sat, 8 Sep 2018 13:40:49 -0600 Subject: [PATCH] Updated content-type parsing to respect wildcards. --- context.go | 2 +- context_test.go | 26 ++++++++++++++++++++------ utils.go | 20 ++++++++++++++++++++ 3 files changed, 41 insertions(+), 7 deletions(-) diff --git a/context.go b/context.go index 063c72f0..5ace552d 100644 --- a/context.go +++ b/context.go @@ -893,7 +893,7 @@ func (c *Context) NegotiateFormat(offered ...string) string { } for _, accepted := range c.Accepted { for _, offert := range offered { - if accepted == offert { + if contentMatches(accepted, offert) { return offert } } diff --git a/context_test.go b/context_test.go index 782f7bed..1832ea8c 100644 --- a/context_test.go +++ b/context_test.go @@ -1122,13 +1122,27 @@ func TestContextNegotiationFormat(t *testing.T) { } func TestContextNegotiationFormatWithAccept(t *testing.T) { - c, _ := CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", nil) - c.Request.Header.Add("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8") + typicalAccept := "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8" + testCases := []struct { + accept string + offer []string + expect string + }{ + {"*/*", []string{MIMEJSON}, MIMEJSON}, + {"", []string{MIMEJSON, MIMEXML}, MIMEJSON}, + {MIMEJSON, []string{MIMEXML}, ""}, + {MIMEJSON, []string{MIMEXML, MIMEJSON, MIMEHTML}, MIMEJSON}, + {typicalAccept, []string{MIMEJSON, MIMEXML}, MIMEXML}, + {typicalAccept, []string{MIMEXML, MIMEHTML}, MIMEHTML}, + {typicalAccept, []string{MIMEJSON}, MIMEJSON}, + } - assert.Equal(t, MIMEXML, c.NegotiateFormat(MIMEJSON, MIMEXML)) - assert.Equal(t, MIMEHTML, c.NegotiateFormat(MIMEXML, MIMEHTML)) - assert.Empty(t, c.NegotiateFormat(MIMEJSON)) + for i, tc := range testCases { + c, _ := CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest("POST", "/", nil) + c.Request.Header.Add("Accept", tc.accept) + assert.Equal(t, tc.expect, c.NegotiateFormat(tc.offer...), fmt.Sprintf("test case %d", i)) + } } func TestContextNegotiationFormatCustum(t *testing.T) { diff --git a/utils.go b/utils.go index bf32c775..c1a36cd8 100644 --- a/utils.go +++ b/utils.go @@ -110,6 +110,26 @@ func parseAccept(acceptHeader string) []string { return out } +// contentMatches returns true if content matches the pattern. +// See RFC2616, section 14.1. +func contentMatches(pattern, content string) bool { + patternParts := strings.SplitN(pattern, "/", 2) + contentParts := strings.SplitN(content, "/", 2) + if patternParts[0] == "*" { + return true + } + if patternParts[0] != contentParts[0] { + return false + } + if patternParts[1] == "*" { + return true + } + if patternParts[1] != contentParts[1] { + return false + } + return true +} + func lastChar(str string) uint8 { if str == "" { panic("The length of the string can't be 0")