diff --git a/.gitignore b/.gitignore index bdd50c95..acbf2f67 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +*~ vendor/* !vendor/vendor.json coverage.out diff --git a/binding/binding.go b/binding/binding.go index 6d58c3cd..63902763 100644 --- a/binding/binding.go +++ b/binding/binding.go @@ -79,6 +79,7 @@ var ( YAML = yamlBinding{} Uri = uriBinding{} Header = headerBinding{} + Plain = plainBinding{} ) // Default returns the appropriate Binding instance based on the HTTP method @@ -101,6 +102,8 @@ func Default(method, contentType string) Binding { return YAML case MIMEMultipartPOSTForm: return FormMultipart + case MIMEPlain: + return Plain default: // case MIMEPOSTForm: return Form } diff --git a/binding/binding_test.go b/binding/binding_test.go index 806f3ac9..73a268f0 100644 --- a/binding/binding_test.go +++ b/binding/binding_test.go @@ -156,6 +156,9 @@ func TestBindingDefault(t *testing.T) { assert.Equal(t, FormMultipart, Default("POST", MIMEMultipartPOSTForm)) assert.Equal(t, FormMultipart, Default("PUT", MIMEMultipartPOSTForm)) + assert.Equal(t, Plain, Default("POST", MIMEPlain)) + assert.Equal(t, Plain, Default("PUT", MIMEPlain)) + assert.Equal(t, ProtoBuf, Default("POST", MIMEPROTOBUF)) assert.Equal(t, ProtoBuf, Default("PUT", MIMEPROTOBUF)) @@ -680,6 +683,46 @@ func TestExistsFails(t *testing.T) { assert.Error(t, err) } +type failRead struct{} + +func (f *failRead) Read(b []byte) (n int, err error) { + return 0, errors.New("my fail") +} + +func (f *failRead) Close() error { + return nil +} + +func TestPlainBinding(t *testing.T) { + p := Plain + assert.Equal(t, "plain", p.Name()) + + var s string + req := requestWithBody("POST", "/", "test string") + assert.NoError(t, p.Bind(req, &s)) + assert.Equal(t, s, "test string") + + var bs []byte + req = requestWithBody("POST", "/", "test []byte") + assert.NoError(t, p.Bind(req, &bs)) + assert.Equal(t, bs, []byte("test []byte")) + + var i int + req = requestWithBody("POST", "/", "test fail") + assert.Error(t, p.Bind(req, &i)) + + req = requestWithBody("POST", "/", "") + req.Body = &failRead{} + assert.Error(t, p.Bind(req, &s)) + + req = requestWithBody("POST", "/", "") + assert.Nil(t, p.Bind(req, nil)) + + var ptr *string + req = requestWithBody("POST", "/", "") + assert.Nil(t, p.Bind(req, ptr)) +} + func TestHeaderBinding(t *testing.T) { h := Header assert.Equal(t, "header", h.Name()) diff --git a/binding/plain.go b/binding/plain.go new file mode 100644 index 00000000..e87e5903 --- /dev/null +++ b/binding/plain.go @@ -0,0 +1,47 @@ +package binding + +import ( + "fmt" + "io/ioutil" + "net/http" + "reflect" + "unsafe" +) + +type plainBinding struct{} + +func (plainBinding) Name() string { + return "plain" +} + +func (plainBinding) Bind(req *http.Request, obj interface{}) error { + if obj == nil { + return nil + } + + v := reflect.ValueOf(obj) + + for v.Kind() == reflect.Ptr { + if v.IsNil() { + return nil + } + v = v.Elem() + } + + all, err := ioutil.ReadAll(req.Body) + if err != nil { + return err + } + + if v.Kind() == reflect.String { + v.SetString(*(*string)(unsafe.Pointer(&all))) + return nil + } + + if _, ok := v.Interface().([]byte); ok { + v.SetBytes(all) + return nil + } + + return fmt.Errorf("type (%T) unkown type", v) +} diff --git a/context.go b/context.go index d9fcc285..60afc70c 100644 --- a/context.go +++ b/context.go @@ -583,6 +583,11 @@ func (c *Context) BindYAML(obj interface{}) error { return c.MustBindWith(obj, binding.YAML) } +// BindHeader is a shortcut for c.MustBindWith(obj, binding.Plain). +func (c *Context) BindPlain(obj interface{}) error { + return c.MustBindWith(obj, binding.Plain) +} + // BindHeader is a shortcut for c.MustBindWith(obj, binding.Header). func (c *Context) BindHeader(obj interface{}) error { return c.MustBindWith(obj, binding.Header) @@ -642,6 +647,11 @@ func (c *Context) ShouldBindYAML(obj interface{}) error { return c.ShouldBindWith(obj, binding.YAML) } +// ShouldBindPlain is a shortcut for c.ShouldBindWith(obj, binding.Header). +func (c *Context) ShouldBindPlain(obj interface{}) error { + return c.ShouldBindWith(obj, binding.Plain) +} + // ShouldBindHeader is a shortcut for c.ShouldBindWith(obj, binding.Header). func (c *Context) ShouldBindHeader(obj interface{}) error { return c.ShouldBindWith(obj, binding.Header) diff --git a/context_test.go b/context_test.go index f7bb0f51..50221e18 100644 --- a/context_test.go +++ b/context_test.go @@ -1436,6 +1436,30 @@ func TestContextBindWithXML(t *testing.T) { assert.Equal(t, 0, w.Body.Len()) } +func TestContextBindPlain(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString(`test string`)) + c.Request.Header.Add("Content-Type", MIMEPlain) + + var s string + + assert.NoError(t, c.BindPlain(&s)) + assert.Equal(t, "test string", s) + assert.Equal(t, 0, w.Body.Len()) + // ======================== + + c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString(`test []byte`)) + c.Request.Header.Add("Content-Type", MIMEPlain) + + var bs []byte + + assert.NoError(t, c.BindPlain(&bs)) + assert.Equal(t, []byte("test []byte"), bs) + assert.Equal(t, 0, w.Body.Len()) + +} + func TestContextBindHeader(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w) @@ -1565,6 +1589,29 @@ func TestContextShouldBindWithXML(t *testing.T) { assert.Equal(t, 0, w.Body.Len()) } +func TestContextShouldBindPlain(t *testing.T) { + w := httptest.NewRecorder() + c, _ := CreateTestContext(w) + c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString(`test string`)) + c.Request.Header.Add("Content-Type", MIMEPlain) + + var s string + + assert.NoError(t, c.ShouldBindPlain(&s)) + assert.Equal(t, "test string", s) + assert.Equal(t, 0, w.Body.Len()) + // ======================== + + c.Request, _ = http.NewRequest("POST", "/", bytes.NewBufferString(`test []byte`)) + c.Request.Header.Add("Content-Type", MIMEPlain) + + var bs []byte + + assert.NoError(t, c.BindPlain(&bs)) + assert.Equal(t, []byte("test []byte"), bs) + assert.Equal(t, 0, w.Body.Len()) + +} func TestContextShouldBindHeader(t *testing.T) { w := httptest.NewRecorder() c, _ := CreateTestContext(w)