From e62a35745b992db2dec18316e0a1ab682cdd8731 Mon Sep 17 00:00:00 2001 From: wei840222 Date: Thu, 24 Jun 2021 11:41:19 +0800 Subject: [PATCH] add test case --- context_test.go | 61 ++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 11 deletions(-) diff --git a/context_test.go b/context_test.go index 91228458..2a4d2185 100644 --- a/context_test.go +++ b/context_test.go @@ -2059,15 +2059,54 @@ func TestRemoteIPFail(t *testing.T) { } func TestContextWithFallbackValueFromRequestContext(t *testing.T) { - var key struct{} - c := &Context{} - c.Request, _ = http.NewRequest("POST", "/", nil) - c.Request = c.Request.WithContext(context.WithValue(context.TODO(), key, "value")) - - assert.Equal(t, "value", c.Value(key)) - - c2 := &Context{} - c2.Request, _ = http.NewRequest("POST", "/", nil) - c2.Request = c2.Request.WithContext(context.WithValue(context.TODO(), "key", "value2")) - assert.Equal(t, "value2", c2.Value("key")) + tests := []struct { + name string + getContextAndKey func() (*Context, interface{}) + value interface{} + }{ + { + name: "c with struct context key", + getContextAndKey: func() (*Context, interface{}) { + var key struct{} + c := &Context{} + c.Request, _ = http.NewRequest("POST", "/", nil) + c.Request = c.Request.WithContext(context.WithValue(context.TODO(), key, "value")) + return c, key + }, + value: "value", + }, + { + name: "c with string context key", + getContextAndKey: func() (*Context, interface{}) { + c := &Context{} + c.Request, _ = http.NewRequest("POST", "/", nil) + c.Request = c.Request.WithContext(context.WithValue(context.TODO(), "key", "value")) + return c, "key" + }, + value: "value", + }, + { + name: "c with nil http.Request", + getContextAndKey: func() (*Context, interface{}) { + c := &Context{} + return c, "key" + }, + value: nil, + }, + { + name: "c with nil http.Request.Context()", + getContextAndKey: func() (*Context, interface{}) { + c := &Context{} + c.Request, _ = http.NewRequest("POST", "/", nil) + return c, "key" + }, + value: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, key := tt.getContextAndKey() + assert.Equal(t, tt.value, c.Value(key)) + }) + } }