func TestValidateContentType(t *testing.T) { data := []struct { hdr string allowed []string err *errors.Validation }{ {"application/json", []string{"application/json"}, nil}, {"application/json", []string{"application/x-yaml", "text/html"}, errors.InvalidContentType("application/json", []string{"application/x-yaml", "text/html"})}, {"text/html; charset=utf-8", []string{"text/html"}, nil}, {"text/html;charset=utf-8", []string{"text/html"}, nil}, {"", []string{"application/json"}, errors.InvalidContentType("", []string{"application/json"})}, {"text/html; charset=utf-8", []string{"application/json"}, errors.InvalidContentType("text/html; charset=utf-8", []string{"application/json"})}, {"application(", []string{"application/json"}, errors.InvalidContentType("application(", []string{"application/json"})}, {"application/json;char*", []string{"application/json"}, errors.InvalidContentType("application/json;char*", []string{"application/json"})}, } for _, v := range data { err := validateContentType(v.allowed, v.hdr) if v.err == nil { assert.NoError(t, err, "input: %q", v.hdr) } else { assert.Error(t, err, "input: %q", v.hdr) assert.IsType(t, &errors.Validation{}, err, "input: %q", v.hdr) assert.Equal(t, v.err.Error(), err.Error(), "input: %q", v.hdr) assert.EqualValues(t, http.StatusUnsupportedMediaType, err.Code()) } } }
// ContentType validates the content type of a request func validateContentType(allowed []string, actual string) *errors.Validation { mt, _, err := mime.ParseMediaType(actual) if err != nil { return errors.InvalidContentType(actual, allowed) } if swag.ContainsStringsCI(allowed, mt) { return nil } return errors.InvalidContentType(actual, allowed) }
func (p *untypedParamBinder) Bind(request *http.Request, routeParams RouteParams, consumer httpkit.Consumer, target reflect.Value) error { // fmt.Println("binding", p.name, "as", p.Type()) switch p.parameter.In { case "query": data, custom, err := p.readValue(request.URL.Query(), target) if err != nil { return err } if custom { return nil } return p.bindValue(data, target) case "header": data, custom, err := p.readValue(request.Header, target) if err != nil { return err } if custom { return nil } return p.bindValue(data, target) case "path": data, custom, err := p.readValue(routeParams, target) if err != nil { return err } if custom { return nil } return p.bindValue(data, target) case "formData": var err error var mt string mt, _, e := httpkit.ContentType(request.Header) if e != nil { // because of the interface conversion go thinks the error is not nil // so we first check for nil and then set the err var if it's not nil err = e } if err != nil { return errors.InvalidContentType("", []string{"multipart/form-data", "application/x-www-form-urlencoded"}) } if mt != "multipart/form-data" && mt != "application/x-www-form-urlencoded" { return errors.InvalidContentType(mt, []string{"multipart/form-data", "application/x-www-form-urlencoded"}) } if mt == "multipart/form-data" { if err := request.ParseMultipartForm(defaultMaxMemory); err != nil { return errors.NewParseError(p.Name, p.parameter.In, "", err) } } if err := request.ParseForm(); err != nil { return errors.NewParseError(p.Name, p.parameter.In, "", err) } if p.parameter.Type == "file" { file, header, err := request.FormFile(p.parameter.Name) if err != nil { return errors.NewParseError(p.Name, p.parameter.In, "", err) } target.Set(reflect.ValueOf(httpkit.File{Data: file, Header: header})) return nil } if request.MultipartForm != nil { data, custom, err := p.readValue(url.Values(request.MultipartForm.Value), target) if err != nil { return err } if custom { return nil } return p.bindValue(data, target) } data, custom, err := p.readValue(url.Values(request.PostForm), target) if err != nil { return err } if custom { return nil } return p.bindValue(data, target) case "body": newValue := reflect.New(target.Type()) if err := consumer.Consume(request.Body, newValue.Interface()); err != nil { if err == io.EOF && p.parameter.Default != nil { target.Set(reflect.ValueOf(p.parameter.Default)) return nil } tpe := p.parameter.Type if p.parameter.Format != "" { tpe = p.parameter.Format } return errors.InvalidType(p.Name, p.parameter.In, tpe, nil) } target.Set(reflect.Indirect(newValue)) return nil default: return errors.New(500, fmt.Sprintf("invalid parameter location %q", p.parameter.In)) } }