예제 #1
0
파일: types.go 프로젝트: rs/rest-layer
// getSubResourceResolver returns a resolver that lists the items of the
// sub-resource r belonging to the parent object currently being resolved.
//
// The parent is taken from p.Source and must be a map; any other source
// resolves to nil without error. Pagination, sort and filter are read from
// the field arguments by listParamResolver, and the lookup is additionally
// restricted to items whose r.ParentField() equals the parent's "id" value.
// The resolver returns the payloads of the matching items.
func getSubResourceResolver(r *resource.Resource) graphql.FieldResolveFn {
	return func(p graphql.ResolveParams) (interface{}, error) {
		parent, isMap := p.Source.(map[string]interface{})
		if !isMap {
			return nil, nil
		}
		lookup, page, perPage, err := listParamResolver(r, p, nil)
		if err != nil {
			return nil, err
		}
		// Only keep the items owned by this parent.
		ownedByParent := schema.Equal{Field: r.ParentField(), Value: parent["id"]}
		lookup.AddQuery(schema.Query{ownedByParent})
		list, err := r.Find(p.Context, lookup, page, perPage)
		if err != nil {
			return nil, err
		}
		payloads := make([]map[string]interface{}, 0, len(list.Items))
		for _, item := range list.Items {
			payloads = append(payloads, item.Payload)
		}
		return payloads, nil
	}
}
예제 #2
0
파일: types.go 프로젝트: rs/rest-layer
// addConnections attaches the relational fields of r to the GraphQL object o.
//
// It is called only after o has been memoized in t (see getObjectType), so a
// cycle of references between resources reuses the already-registered object
// instead of recursing forever. Two kinds of fields are added:
//   - every schema field whose validator is a *schema.Reference becomes a
//     nested object field resolved through the referenced resource;
//   - every sub-resource of r becomes a list field resolved by
//     getSubResourceResolver.
//
// It panics if a reference points to a path that idx cannot resolve.
func (t types) addConnections(o *graphql.Object, idx resource.Index, r *resource.Resource) {
	for fieldName, field := range r.Schema().Fields {
		ref, isRef := field.Validator.(*schema.Reference)
		if !isRef {
			continue
		}
		target, found := idx.GetResource(ref.Path, nil)
		if !found {
			log.Panicf("resource reference not found: %s", ref.Path)
		}
		o.AddFieldConfig(fieldName, &graphql.Field{
			Description: field.Description,
			Type:        t.getObjectType(idx, target),
			Args:        getFArgs(field.Params),
			Resolve:     getSubFieldResolver(fieldName, target, field),
		})
	}
	for _, sub := range r.GetResources() {
		subName := sub.Name()
		o.AddFieldConfig(subName, &graphql.Field{
			Description: fmt.Sprintf("Connection to %s", subName),
			Type:        graphql.NewList(t.getObjectType(idx, sub)),
			Args:        listArgs,
			Resolve:     getSubResourceResolver(sub),
		})
	}
}
예제 #3
0
파일: main.go 프로젝트: rs/rest-layer
// NewBasicAuthHandler returns a middleware that authenticates requests with
// HTTP Basic credentials checked against the provided users resource.
//
// The Basic-auth username is used as the user item id. The password is
// verified against the item's "password" field with schema.VerifyPassword.
// On success the user item is stored in the request context (via
// NewContextWithUser) and next is called.
//
// Missing credentials, unknown users and wrong passwords all receive the
// same 401 response with a WWW-Authenticate challenge. The response therefore
// does not reveal which usernames exist, and clients are re-prompted as
// RFC 7235 expects. Any other error from the users storage is reported as a
// 500.
func NewBasicAuthHandler(users *resource.Resource) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if u, p, ok := r.BasicAuth(); ok {
				// Lookup the user by its id
				ctx := r.Context()
				user, err := users.Get(ctx, u)
				if user != nil && err == resource.ErrUnauthorized {
					// Ignore unauthorized errors set by ourselves: the item
					// was found, which is all this lookup needs.
					err = nil
				}
				switch {
				case err == resource.ErrNotFound, err == nil && user == nil:
					// Unknown user: fall through to the 401 challenge below,
					// exactly as for a wrong password, so usernames cannot be
					// enumerated (and a nil user is never dereferenced).
				case err != nil:
					// Storage failure unrelated to the supplied credentials.
					http.Error(w, err.Error(), http.StatusInternalServerError)
					return
				case schema.VerifyPassword(user.Payload["password"], []byte(p)):
					// Store the auth user into the context for later use
					r = r.WithContext(NewContextWithUser(ctx, user))
					next.ServeHTTP(w, r)
					return
				}
			}
			// Stop the middleware chain and return a 401 HTTP error
			w.Header().Set("WWW-Authenticate", `Basic realm="API"`)
			http.Error(w, "Please provide proper credentials", http.StatusUnauthorized)
		})
	}
}
예제 #4
0
파일: main.go 프로젝트: rs/rest-layer
// NewJWTHandler returns a middleware that authenticates requests carrying a
// JWT in the OAuth2 bearer position.
//
// Requests without a token are passed through unchanged so that REST Layer
// hooks can decide whether the resource is public. Otherwise the token is
// parsed and verified with jwtKeyFunc. Its "user_id" claim is looked up in
// the users resource, and the resulting user item is stored in the request
// context (via NewContextWithUser) before next is called.
//
// Failures are reported as:
//   - 401 for an unparsable or invalid token;
//   - 400 for a missing or empty user_id claim;
//   - 403 when the user does not exist;
//   - 500 for any other storage error.
func NewJWTHandler(users *resource.Resource, jwtKeyFunc jwt.Keyfunc) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			tok, parseErr := request.ParseFromRequest(r, request.OAuth2Extractor, jwtKeyFunc)
			switch {
			case parseErr == request.ErrNoTokenInRequest:
				// Anonymous request: let REST Layer hooks decide if the resource is public or not.
				next.ServeHTTP(w, r)
				return
			case parseErr != nil || !tok.Valid:
				// Here you may want to return JSON error
				http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
				return
			}
			userID, hasID := tok.Claims.(jwt.MapClaims)["user_id"].(string)
			if !hasID || userID == "" {
				// Malformed token: the user_id claim is missing or empty.
				http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
				return
			}
			ctx := r.Context()
			user, lookupErr := users.Get(ctx, userID)
			if user != nil && lookupErr == resource.ErrUnauthorized {
				// Ignore unauthorized errors set by ourselves (see AuthResourceHook).
				lookupErr = nil
			}
			if lookupErr != nil {
				status, msg := http.StatusInternalServerError, lookupErr.Error()
				if lookupErr == resource.ErrNotFound {
					status, msg = http.StatusForbidden, "Invalid credential"
				}
				http.Error(w, msg, status)
				return
			}
			ctx = NewContextWithUser(ctx, user)
			// If xlog is set up, tag the logger with the authenticated user.
			xlog.FromContext(ctx).SetField("user_id", user.ID)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
예제 #5
0
파일: query.go 프로젝트: rs/rest-layer
// getGetQuery returns the GraphQL root field that fetches a single item of r
// by its string "id" argument.
//
// When the id argument is absent (or not a string) the field resolves to nil
// without error. Errors returned by r.Get are propagated unchanged; on success
// the item's payload is returned.
func (t types) getGetQuery(idx resource.Index, r *resource.Resource) *graphql.Field {
	args := graphql.FieldConfigArgument{
		"id": &graphql.ArgumentConfig{Type: graphql.String},
	}
	resolve := func(p graphql.ResolveParams) (interface{}, error) {
		id, isString := p.Args["id"].(string)
		if !isString {
			return nil, nil
		}
		item, err := r.Get(p.Context, id)
		if err != nil {
			return nil, err
		}
		return item.Payload, nil
	}
	return &graphql.Field{
		Description: fmt.Sprintf("Get %s by id", r.Name()),
		Type:        t.getObjectType(idx, r),
		Args:        args,
		Resolve:     resolve,
	}
}
예제 #6
0
파일: query.go 프로젝트: rs/rest-layer
// getListQuery returns the GraphQL root field that lists items of r.
//
// The field accepts the shared list arguments (listArgs). When params is
// non-nil, its "filter" value is applied to every query in addition to the
// field arguments (see listParamResolver). The resolver returns the payloads
// of the found items.
func (t types) getListQuery(idx resource.Index, r *resource.Resource, params url.Values) *graphql.Field {
	resolve := func(p graphql.ResolveParams) (interface{}, error) {
		lookup, page, perPage, err := listParamResolver(r, p, params)
		if err != nil {
			return nil, err
		}
		list, err := r.Find(p.Context, lookup, page, perPage)
		if err != nil {
			return nil, err
		}
		payloads := make([]map[string]interface{}, 0, len(list.Items))
		for _, item := range list.Items {
			payloads = append(payloads, item.Payload)
		}
		return payloads, nil
	}
	return &graphql.Field{
		Description: fmt.Sprintf("Get a list of %s", r.Name()),
		Type:        graphql.NewList(t.getObjectType(idx, r)),
		Args:        listArgs,
		Resolve:     resolve,
	}
}
예제 #7
0
파일: types.go 프로젝트: rs/rest-layer
// getSubFieldResolver returns a resolver for a reference field.
//
// The resolver reads the referenced item id from the parent object's
// parentField entry and fetches that item from r. The parent is taken from
// p.Source and must be a map. Any other source, or a parent whose
// parentField is absent or nil (an unset reference), resolves to nil without
// querying r.
//
// Otherwise the fetched item's payload is passed through f.Handler (when set)
// and then through the validator's Serialize method (when the validator
// implements schema.FieldSerializer). The first error from r.Get, the
// handler or the serializer is returned.
func getSubFieldResolver(parentField string, r *resource.Resource, f schema.Field) graphql.FieldResolveFn {
	// Detect the optional serializer once, when the field is built.
	s, serialize := f.Validator.(schema.FieldSerializer)
	return func(p graphql.ResolveParams) (data interface{}, err error) {
		parent, ok := p.Source.(map[string]interface{})
		if !ok {
			return nil, nil
		}
		id := parent[parentField]
		if id == nil {
			// Unset optional reference: resolve to null rather than turning
			// a lookup of a nil id into a field error.
			return nil, nil
		}
		var item *resource.Item
		// Get sub field resource
		item, err = r.Get(p.Context, id)
		if err != nil {
			return nil, err
		}
		data = item.Payload
		if f.Handler != nil {
			data, err = f.Handler(p.Context, data, p.Args)
		}
		if err == nil && serialize {
			data, err = s.Serialize(data)
		}
		return data, err
	}
}
예제 #8
0
파일: query.go 프로젝트: rs/rest-layer
// listParamResolver builds the lookup and pagination settings for a list
// query on r from the GraphQL arguments in p and, when params is non-nil,
// from its "filter" value as well.
//
// All arguments are optional strings:
//   - "page" is a 1-based page number (default 1).
//   - "limit" is the number of items per page. The default is
//     r.Conf().PaginationDefaultLimit when positive, otherwise -1 (no
//     pagination).
//   - "sort" and "filter" are expressions validated against r's schema.
//
// An error is returned for an unparsable page or limit, a page below 1, a
// page other than 1 when no limit applies, or an invalid sort/filter
// expression.
func listParamResolver(r *resource.Resource, p graphql.ResolveParams, params url.Values) (lookup *resource.Lookup, page int, perPage int, err error) {
	page = 1
	// -1 disables pagination unless the resource defines a default limit.
	perPage = -1
	if l := r.Conf().PaginationDefaultLimit; l > 0 {
		perPage = l
	}
	if s, ok := p.Args["page"].(string); ok && s != "" {
		n, convErr := strconv.ParseUint(s, 10, 32)
		// Pages are 1-based (the default above is 1); reject 0 rather than
		// passing it on to Find.
		if convErr != nil || n < 1 {
			return nil, 0, 0, errors.New("invalid `page` parameter")
		}
		page = int(n)
	}
	if s, ok := p.Args["limit"].(string); ok && s != "" {
		n, convErr := strconv.ParseUint(s, 10, 32)
		if convErr != nil {
			return nil, 0, 0, errors.New("invalid `limit` parameter")
		}
		perPage = int(n)
	}
	if perPage == -1 && page != 1 {
		return nil, 0, 0, errors.New("cannot use `page` parameter with no `limit` parameter on a resource with no default pagination size")
	}
	lookup = resource.NewLookup()
	if sort, ok := p.Args["sort"].(string); ok && sort != "" {
		if sortErr := lookup.SetSort(sort, r.Validator()); sortErr != nil {
			return nil, 0, 0, fmt.Errorf("invalid `sort` parameter: %v", sortErr)
		}
	}
	if filter, ok := p.Args["filter"].(string); ok && filter != "" {
		if filterErr := lookup.AddFilter(filter, r.Validator()); filterErr != nil {
			return nil, 0, 0, fmt.Errorf("invalid `filter` parameter: %v", filterErr)
		}
	}
	if params != nil {
		// Extra filter coming from the HTTP query string, applied on top of
		// the GraphQL argument.
		if filter := params.Get("filter"); filter != "" {
			if filterErr := lookup.AddFilter(filter, r.Validator()); filterErr != nil {
				return nil, 0, 0, fmt.Errorf("invalid `filter` parameter: %v", filterErr)
			}
		}
	}
	return lookup, page, perPage, nil
}
예제 #9
0
파일: types.go 프로젝트: rs/rest-layer
// getObjectType returns the GraphQL object type describing resource r.
//
// Types are memoized in t under the resource name, so each resource yields a
// single *graphql.Object. A new object is built from r's schema fields. It is
// registered in t before addConnections runs, so that references back to r
// (directly or through other resources) find the cached object instead of
// recursing.
func (t types) getObjectType(idx resource.Index, r *resource.Resource) *graphql.Object {
	name := r.Name()
	if cached := t[name]; cached != nil {
		return cached
	}
	obj := graphql.NewObject(graphql.ObjectConfig{
		Name:        name,
		Description: r.Schema().Description,
		Fields:      getFields(idx, r.Schema()),
	})
	// Register first: addConnections may come back here for the same name.
	t[name] = obj
	t.addConnections(obj, idx, r)
	return obj
}