Example #1
0
File: storage.go Project: gourd/kit
// LoadAuthorize looks up AuthorizeData by a code.
// Client information MUST be loaded together.
// Optionally can return error if expired.
//
// The authorize record is read from the KeyAuth store by its "code" field,
// its client is resolved through storage.GetClient and, when the record has
// a UserID, the user is read from the KeyUser store by "id". Any failure is
// returned and d is left nil.
func (storage *Storage) LoadAuthorize(code string) (d *osin.AuthorizeData, err error) {

	// TODO: use logger := log.NewContext(,sg)
	logger, errLogger := msg, errMsg
	logger.Log(
		"method", "LoadAuthorize",
		"code", code)

	// loading osin using osin storage
	srv, err := store.Get(storage.ctx, KeyAuth)
	if err != nil {
		return
	}
	defer srv.Close()

	// e is freshly allocated and therefore never nil; the only failure
	// signal of the lookup is the error returned by srv.One.
	e := &AuthorizeData{}
	conds := store.NewConds()
	conds.Add("code", code)

	if err = srv.One(conds, e); err != nil {
		return
	}

	// load client here
	var ok bool
	cli, err := storage.GetClient(e.ClientID)
	if err != nil {
		return
	}
	if e.Client, ok = cli.(*Client); !ok {
		err = store.Error(http.StatusInternalServerError,
			"Internal Server Error")

		errLogger.Log(
			"method", "GetClient",
			"code", code,
			"cond", conds,
			"raw client", fmt.Sprintf("%#v", cli),
			"message", "Unable to cast raw client into Client")
		return
	}

	// load user data here
	if e.UserID != "" {
		userStore, uerr := store.Get(storage.ctx, KeyUser)
		if uerr != nil {
			return nil, uerr
		}
		// release the user store like the auth store above
		defer userStore.Close()

		// a failed lookup must not attach a zero-valued User to the
		// authorize data, so the error is reported to the caller
		user := &User{}
		if uerr = userStore.One(store.NewConds().Add("id", e.UserID), user); uerr != nil {
			return nil, uerr
		}
		e.UserData = user
	}

	d = e.ToOsin()
	return
}
Example #2
0
// createStoreDummies creates a dummy user and a dummy client owned by that
// user, writing both directly through their stores. Any store error panics.
func createStoreDummies(ctx context.Context, password, redirect string) (*oauth2.Client, *oauth2.User) {

	// the setup cannot continue after any store error
	mustOK := func(err error) {
		if err != nil {
			panic(err)
		}
	}

	// the user goes first because the client refers to its ID
	userStore, err := store.Get(ctx, oauth2.KeyUser)
	mustOK(err)
	user := dummyNewUser(password)
	mustOK(userStore.Create(store.NewConds(), user))

	// then the client that belongs to the user
	clientStore, err := store.Get(ctx, oauth2.KeyClient)
	mustOK(err)
	client := dummyNewClient(redirect)
	client.UserID = user.ID
	mustOK(clientStore.Create(store.NewConds(), client))

	return client, user
}
Example #3
0
// NewUserFunc creates the default parser of login HTTP request.
//
// The returned UserFunc reads the user identifier from the form field named
// idName, searches the given store by "email" when the identifier is an
// email address or by "username" otherwise, and returns the found entity as
// an OAuth2User. r.Form is read directly, so the request form is assumed to
// be parsed already (TODO confirm against callers).
//
// Lookup failures are always returned as the expanded store error; failures
// other than "not found" additionally carry a server-side message.
func NewUserFunc(idName string) UserFunc {
	return func(r *http.Request, us store.Store) (ou OAuth2User, err error) {

		var c store.Conds

		id := r.Form.Get(idName)

		if id == "" {
			serr := store.Error(http.StatusBadRequest, "empty user identifier")
			err = serr
			return
		}

		// different condition based on the user_id field format
		if govalidator.IsEmail(id) {
			c = store.NewConds().Add("email", id)
		} else {
			c = store.NewConds().Add("username", id)
		}

		// get user from database
		u := us.AllocEntity()
		err = us.One(c, u)

		if err != nil {
			serr := store.ExpandError(err)
			if serr.Status != http.StatusNotFound {
				// unexpected failure: note it for the server side
				serr.TellServer("Error searching user %#v: %s", id, serr.ServerMsg)
			}
			// return the expanded error in both cases so that the
			// server message attached above is not lost
			err = serr
			return
		}

		// if user does not exists
		if u == nil {
			serr := store.Error(http.StatusBadRequest, "Username or Password incorrect")
			serr.TellServer("Unknown user %#v attempt to login", id)
			err = serr
			return
		}

		// cast the user as OAuth2User
		// (no password check is performed here)
		ou, ok := u.(OAuth2User)
		if !ok {
			serr := store.Error(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
			serr.TellServer("User cannot be cast as OAuth2User")
			err = serr
			return
		}

		return
	}
}
Example #4
0
// TestConds_branching checks that two groups of conditions joined with
// store.Or are translated by upperio.Conds into a query that matches the
// expected number of rows of the dummy data set.
func TestConds_branching(t *testing.T) {

	fn := "./test3.tmp"

	// clean up the temp database however the test ends; registered first so
	// that it runs after the connection has been closed
	defer func() {
		if err := os.Remove(fn); err != nil && !os.IsNotExist(err) {
			t.Error(err.Error())
		}
	}()

	// two branch query
	cond1 := store.NewConds().
		Add("HelloWorld =", "foo bar").
		Add("FooBar !=", "hello world")
	cond2 := store.NewConds().
		Add("HelloWorld =", "foo bar 2").
		Add("FooBar !=", "hello world")

	q := store.NewQuery().
		AddCond("", cond1).
		AddCond("", cond2)

	q.GetConds().SetRel(store.Or)

	// test source
	source := upperio.NewSource(testUpperDb(fn))

	// add dummy data to the database
	if err := testUpperDbData(source); err != nil {
		t.Fatal(err.Error())
	}

	// connect to database again; without a connection nothing below can
	// run, so stop right here on failure
	conn, err := source.Open()
	if err != nil {
		t.Fatal(err.Error())
	}
	defer conn.Close()

	// query connection
	sess := conn.Raw().(db.Database)
	coll, err := sess.Collection("dummy_data")
	if err != nil {
		t.Fatal(err.Error())
	}
	res := coll.Find(upperio.Conds(q.GetConds()))
	var tds []testData
	if err := res.All(&tds); err != nil {
		t.Fatal(err.Error())
	}

	expLen := 2
	if l := len(tds); l != expLen {
		t.Errorf("result set size expected: %d, got: %d\ntest data set:\t%#v",
			expLen, l, tds)
	}

}
Example #5
0
File: storage.go Project: gourd/kit
// SaveAuthorize saves authorize data.
//
// The osin data is converted into the storage AuthorizeData type, stamped
// with the ID of its client and created in the KeyAuth store.
func (storage *Storage) SaveAuthorize(d *osin.AuthorizeData) (err error) {

	// TODO: use logger := log.NewContext(,sg)
	logger := msg
	logger.Log(
		"method", "SaveAuthorize",
		"*osin.AuthorizeData", d)

	srv, err := store.Get(storage.ctx, KeyAuth)
	if err != nil {
		return err
	}
	defer srv.Close()

	// convert into the database type
	record := &AuthorizeData{}
	if err = record.ReadOsin(d); err != nil {
		return err
	}

	// the client ID is stored along with the auth data
	record.ClientID = record.Client.GetId()

	return srv.Create(store.NewConds(), record)
}
Example #6
0
File: storage.go Project: gourd/kit
// LoadRefresh retrieves refresh AccessData. Client information MUST be loaded together.
// AuthorizeData and AccessData DON'T NEED to be loaded if not easily available.
// Optionally can return error if expired.
//
// The record is read from the KeyAccess store by its "refresh_token" field
// and completed by storage.loadAccessSupp. Any failure is returned and d is
// left nil.
func (storage *Storage) LoadRefresh(token string) (d *osin.AccessData, err error) {

	// TODO: use logger := log.NewContext(,sg)
	logger := msg
	logger.Log(
		"method", "LoadRefresh",
		"token", token)

	srv, err := store.Get(storage.ctx, KeyAccess)
	if err != nil {
		return
	}
	defer srv.Close()

	// e is freshly allocated and therefore never nil; the only failure
	// signal of the lookup is the error returned by srv.One.
	e := &AccessData{}
	conds := store.NewConds()
	conds.Add("refresh_token", token)

	if err = srv.One(conds, e); err != nil {
		return
	}

	// load supplementary data
	if err = storage.loadAccessSupp(e); err != nil {
		return
	}

	d = e.ToOsin()
	return
}
Example #7
0
// TestBasicConds_AddGetMapErr expects GetMap to fail once the same property
// has been added twice.
func TestBasicConds_AddGetMapErr(t *testing.T) {
	t.Parallel()

	conflicting := store.NewConds().
		Add("foo", "bar").
		Add("foo", "again")

	if _, err := conflicting.GetMap(); err == nil {
		t.Errorf("Failed to return error with conflicting map conditions")
	}
}
Example #8
0
// TestBasicConds_SetGetRel expects new Conds to report store.And and to
// report store.Or after SetRel(store.Or).
func TestBasicConds_SetGetRel(t *testing.T) {
	t.Parallel()

	conds := store.NewConds().Add("foo", "bar").Add("hello", "world")

	// initial relation
	if rel := conds.GetRel(); rel == store.And {
		t.Log("Conds Rel initialized as And")
	} else {
		t.Errorf("Conds Rel flag is not initialized as And")
	}

	// relation after the change
	conds.SetRel(store.Or)
	if rel := conds.GetRel(); rel == store.Or {
		t.Log("Conds Rel changed to Or")
	} else {
		t.Errorf("Failed to set Conds Rel to Or")
	}
}
Example #9
0
// TestBasicConds_AddGetAll expects GetAll to return the added conditions in
// insertion order with their properties and values intact.
func TestBasicConds_AddGetAll(t *testing.T) {
	t.Parallel()
	c := store.NewConds().Add("foo", "bar").Add("hello", "world")
	a := c.GetAll()

	// guard the index accesses below so that a short result fails the
	// test instead of panicking
	if len(a) < 2 {
		t.Fatalf("expected at least 2 conditions, got %d: %#v", len(a), a)
	}

	if a[0].Prop != "foo" {
		t.Errorf("Failed to add: %#v", a[0])
	} else if a[0].Value != "bar" {
		t.Errorf("Failed testing value with original string: %#v -> %s", a[0].Value, a[0].Value)
	}

	if a[1].Prop != "hello" {
		t.Errorf("Failed to add: %#v", a[1])
	} else if a[1].Value != "world" {
		t.Errorf("Failed testing value with original string: %#v -> %s", a[1].Value, a[1].Value)
	}
}
Example #10
0
File: storage.go Project: gourd/kit
// GetClient implements osin.Storage.GetClient
//
// The client is read from the KeyClient store by its "id" field. A failure
// to obtain the store is reported as an internal server error; a failed
// lookup is logged and its error returned as is. c stays nil on any error.
func (storage *Storage) GetClient(id string) (c osin.Client, err error) {

	// TODO: use logger := log.NewContext(,sg)
	logger, errLogger := msg, errMsg
	logger.Log(
		"method", "GetClient",
		"id", id)

	srv, err := store.Get(storage.ctx, KeyClient)
	if err != nil {
		serr := store.Error(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
		serr.TellServer("unable to get client store: %s", err)
		err = serr
		return
	}
	defer srv.Close()

	// e is freshly allocated and therefore never nil; the only failure
	// signal of the lookup is the error returned by srv.One.
	e := &Client{}
	conds := store.NewConds()
	conds.Add("id", id)

	if err = srv.One(conds, e); err != nil {
		serr := store.ExpandError(err)
		errLogger.Log(
			"method", "GetClient",
			"id", id,
			"cond", conds,
			"message", "Failed running One()",
			"error", serr.ServerMsg)
		return
	}

	c = e
	return
}
Example #11
0
// TestBasicConds_AddGetMap expects GetMap to return every added property
// mapped to its value when no property is repeated.
func TestBasicConds_AddGetMap(t *testing.T) {
	t.Parallel()
	c := store.NewConds().Add("foo", "bar").Add("hello", "world")
	m, err := c.GetMap()

	if err != nil {
		t.Errorf("Error in GetMap(): %s", err.Error())
	}

	// each message has two verbs: the actual value, then the expected one
	if m["foo"] != "bar" {
		t.Errorf("Failed to get proper map: m[\"foo\"] is \"%#v\" instead of \"%s\"",
			m["foo"], "bar")
	}
	if m["hello"] != "world" {
		t.Errorf("Failed to get proper map: m[\"hello\"] is \"%#v\" instead of \"%s\"",
			m["hello"], "world")
	}
}
Example #12
0
File: storage.go Project: gourd/kit
// loadAccessSupp loads supplementary data onto an *AccessData
//
// It resolves the client through storage.GetClient, decodes the serialized
// previous AuthorizeData and AccessData when their JSON fields are set and,
// when the record has a UserID, reads the user from the KeyUser store by
// "id". The first client or user failure is returned.
func (storage *Storage) loadAccessSupp(e *AccessData) (err error) {

	// load client here
	var ok bool
	cli, err := storage.GetClient(e.ClientID)
	if err != nil {
		return
	} else if e.Client, ok = cli.(*Client); !ok {
		serr := store.Error(http.StatusInternalServerError,
			"Internal Server Error")
		serr.TellServer("Unable to cast client into Client type: %#v", cli)
		err = serr
		return
	}
	e.ClientID = e.Client.GetId()

	// unserialize previous AuthorizeData here
	// NOTE(review): decode errors are ignored, so a partially decoded value
	// may be attached; presumably this nested data is best-effort — confirm
	// before making a failure fatal.
	if e.AuthorizeDataJSON != "" {
		ad := &AuthorizeData{}
		json.Unmarshal([]byte(e.AuthorizeDataJSON), ad)
		e.AuthorizeData = ad
	}

	// unserialize previous AccessData here
	// NOTE(review): decode errors are ignored here as well — confirm.
	if e.AccessDataJSON != "" {
		ad := &AccessData{}
		json.Unmarshal([]byte(e.AccessDataJSON), ad)
		e.AccessData = ad
	}

	// load user data here
	if e.UserID != "" {
		userStore, uerr := store.Get(storage.ctx, KeyUser)
		if uerr != nil {
			return uerr
		}
		// release the user store once the lookup is done
		defer userStore.Close()

		// a failed lookup must not attach a zero-valued User to the
		// access data, so the error is reported to the caller
		user := &User{}
		if uerr = userStore.One(store.NewConds().Add("id", e.UserID), user); uerr != nil {
			return uerr
		}
		e.UserData = user
	}

	return nil

}
Example #13
0
File: storage.go Project: gourd/kit
// RemoveRefresh revokes or deletes refresh AccessData.
//
// Records of the KeyAccess store whose "refresh_token" equals token are
// deleted; the store's error, if any, is returned.
func (storage *Storage) RemoveRefresh(token string) (err error) {

	// TODO: use logger := log.NewContext(,sg)
	logger := msg
	logger.Log(
		"method", "RemoveRefresh",
		"token", token)

	srv, err := store.Get(storage.ctx, KeyAccess)
	if err != nil {
		return err
	}
	defer srv.Close()

	byToken := store.NewConds()
	byToken.Add("refresh_token", token)
	return srv.Delete(byToken)
}
Example #14
0
File: storage.go Project: gourd/kit
// RemoveAuthorize revokes or deletes the authorization code.
//
// Records of the KeyAuth store whose "code" equals code are deleted; the
// store's error, if any, is returned.
func (storage *Storage) RemoveAuthorize(code string) (err error) {

	// TODO: use logger := log.NewContext(,sg)
	logger := msg
	logger.Log(
		"method", "RemoveAuthorize",
		"code", code)

	srv, err := store.Get(storage.ctx, KeyAuth)
	if err != nil {
		return err
	}
	defer srv.Close()

	byCode := store.NewConds()
	byCode.Add("code", code)
	return srv.Delete(byCode)
}
Example #15
0
// TestNewConds makes sure the result of store.NewConds can be used as a
// store.Conds.
func TestNewConds(t *testing.T) {
	t.Parallel()

	// the explicitly typed declaration is the actual check: it only
	// compiles when NewConds yields something assignable to store.Conds
	var conds store.Conds = store.NewConds()
	t.Logf("NewCond can return Conds: %#v", conds)
}
Example #16
0
File: storage.go Project: gourd/kit
// SaveAccess writes AccessData.
// If RefreshToken is not blank, it must save in a way that can be loaded using LoadRefresh.
//
// The osin data is converted into the storage AccessData type, stamped with
// the ID of its client and created in the KeyAccess store. Nested
// AuthorizeData and previous AccessData, when present, are stored on the
// record as JSON strings.
func (storage *Storage) SaveAccess(ad *osin.AccessData) (err error) {

	// TODO: use logger := log.NewContext(,sg)
	logger, errLogger := msg, errMsg
	logger.Log(
		"method", "SaveAccess",
		"*osin.AccessData", ad)

	srv, err := store.Get(storage.ctx, KeyAccess)
	if err != nil {
		return err
	}
	defer srv.Close()

	// convert into the database type
	record := &AccessData{}
	if err = record.ReadOsin(ad); err != nil {
		return err
	}

	// the client ID is stored along with the access data
	record.ClientID = record.Client.GetId()

	// nested AuthorizeData travels as JSON
	if ad.AuthorizeData != nil {
		authData := &AuthorizeData{}
		if err = authData.ReadOsin(ad.AuthorizeData); err != nil {
			return err
		}
		encoded, encErr := json.Marshal(authData)
		if encErr != nil {
			return encErr
		}
		record.AuthorizeDataJSON = string(encoded)
	}

	// previous AccessData travels as JSON too
	if ad.AccessData != nil {
		prev := &AccessData{}
		if err = prev.ReadOsin(ad.AccessData); err != nil {
			return err
		}
		// forget data of too long ago: only one level of history is kept
		prev.AccessData = nil

		encoded, encErr := json.Marshal(prev)
		if encErr != nil {
			return encErr
		}
		record.AccessDataJSON = string(encoded)
	}

	// create in database; a failure is logged and then returned
	err = srv.Create(store.NewConds(), record)
	if err != nil {
		serr := store.ExpandError(err)
		errLogger.Log(
			"method", "SaveAccess",
			"*osin.AccessData", ad,
			"err", serr.ServerMsg)
	}
	return err
}
Example #17
0
// UserStoreServices builds the RESTful services of the user store.
//
// The returned httpservice.Services holds one service per key "create",
// "retrieve", "update", "list" and "delete". Each service wraps the endpoint
// of the same key from endpoints, is routed with paths, and gets its request
// decoder, HTTP methods and middlewares (protocol wrapping, preparation and
// permission checking) set here.
func UserStoreServices(paths httpservice.Paths, endpoints map[string]endpoint.Endpoint) (handlers httpservice.Services) {

	// variables to use later
	noun := paths.Noun()
	storeKey := KeyUser
	getStore := func(ctx context.Context) (s *UserStore, err error) {
		raw, err := store.Get(ctx, storeKey)
		if err != nil {
			return
		}

		s, ok := raw.(*UserStore)
		if !ok {
			err = fmt.Errorf(`store.Get(KeyUser) does not return *KeyUser`)
			return
		}
		return
	}

	// define default middlewares
	var prepareCreate endpoint.Middleware = func(inner endpoint.Endpoint) endpoint.Endpoint {
		return func(ctx context.Context, request interface{}) (respond interface{}, err error) {
			// placeholder: anything you want to do with the entity
			//              before append to database
			httpservice.EnforceCreate(request)
			return inner(ctx, request)
		}
	}

	var prepareUpdate endpoint.Middleware = func(inner endpoint.Endpoint) endpoint.Endpoint {
		return func(ctx context.Context, request interface{}) (response interface{}, err error) {

			sReq := request.(*httpservice.Request)

			// get context information
			r := gourdctx.HTTPRequest(ctx)
			if r == nil {
				serr := store.ErrorInternal
				serr.ServerMsg = "missing request in context"
				err = serr
				return
			}

			el := &[]User{}
			q := sReq.Query

			// get store
			s, err := getStore(ctx)
			if err != nil {
				serr := store.ErrorInternal
				serr.ServerMsg = fmt.Sprintf("error obtaining %s store (%s)", storeKey, err)
				err = serr
				return
			}
			defer s.Close()

			// find the previous content of the id
			err = s.Search(q).All(el)
			if err != nil {
				serr := store.ErrorInternal
				serr.ServerMsg = fmt.Sprintf("error searching %s: %s",
					noun.Singular(), err)
				err = serr
				return
			}

			// tell the inner
			if len(*el) > 0 {
				sReq.Previous = &(*el)[0]
			}

			// enforce agreement on sReq.Payload with previous sReq.Entity
			httpservice.EnforceUpdate(sReq.Previous, sReq.Payload)

			// placeholder: anything you want to do with the entity
			//              before update to database
			return inner(ctx, sReq)
		}
	}

	// prepareList makes sure the list in the inner response is never nil
	var prepareList endpoint.Middleware = func(inner endpoint.Endpoint) endpoint.Endpoint {
		return func(ctx context.Context, request interface{}) (response interface{}, err error) {
			response, err = inner(ctx, request)
			if err != nil {
				return
			}

			vmap := response.(map[string]interface{})
			list := vmap[noun.Plural()].(*[]User)
			if list == nil {
				// a nil pointer cannot be written through; substitute
				// a pointer to a fresh empty list instead
				list = &[]User{}
			} else if *list == nil {
				*list = make([]User, 0)
			}
			vmap[noun.Plural()] = list

			// placeholder: anything you want to do with the entity
			//              list response
			return vmap, nil
		}
	}

	// wrap inner response with default protocol
	var prepareProtocol endpoint.Middleware = func(inner endpoint.Endpoint) endpoint.Endpoint {
		return func(ctx context.Context, request interface{}) (response interface{}, err error) {

			v, err := inner(ctx, request)
			if err != nil {
				return
			}

			switch v.(type) {
			case map[string]interface{}:
				response = store.ExpandResponse(v.(map[string]interface{}))
			default:
				response = store.NewResponse(noun.Plural(), v)
			}

			return
		}
	}

	// generates request permission checker middleware
	// (permission is checked on the request before calling the inner endpoint)
	checkPermBefore := func(permission string) endpoint.Middleware {
		return func(inner endpoint.Endpoint) endpoint.Endpoint {
			return func(ctx context.Context, request interface{}) (response interface{}, err error) {
				m := perm.GetMux(ctx)
				err = m.Allow(ctx, permission, request)
				if err != nil {
					return
				}
				return inner(ctx, request)
			}
		}
	}

	// generates response permission checker middleware
	// (permission is checked on request and inner response after the inner
	// endpoint returned)
	checkPermAfter := func(permission string) endpoint.Middleware {
		return func(inner endpoint.Endpoint) endpoint.Endpoint {
			return func(ctx context.Context, request interface{}) (response interface{}, err error) {

				v, err := inner(ctx, request)
				if err != nil {
					return
				}

				m := perm.GetMux(ctx)
				err = m.Allow(ctx, permission, request, v)
				if err != nil {
					return
				}

				response = v
				return

			}
		}
	}

	//
	// ==== raw decode functions
	//

	decodeServiceIDReq := func(ctx context.Context, r *http.Request) (request *httpservice.Request, err error) {
		id := r.URL.Query().Get(":id") // will change
		cond := store.NewConds().Add("id", id)
		request = &httpservice.Request{
			Request: r,
			Query:   store.NewQuery().SetConds(cond),
		}
		return
	}

	decodeJSONEntity := func(ctx context.Context, r *http.Request) (entity *User, err error) {
		// allocate entity
		entity = &User{}

		// decode request
		dec := json.NewDecoder(r.Body)
		err = dec.Decode(entity)
		return
	}

	//
	// ==== httptransport.DecodeRequestFunc implementations
	//

	// decodeIDReq generically decoded :id field
	// (works with pat based URL routing, router specific)
	var decodeIDReq httptransport.DecodeRequestFunc = func(ctx context.Context, r *http.Request) (request interface{}, err error) {
		return decodeServiceIDReq(ctx, r)
	}

	// decodeListReq decode query for list endpoint
	var decodeListReq httptransport.DecodeRequestFunc = func(ctx context.Context, r *http.Request) (request interface{}, err error) {

		sReq := &httpservice.Request{
			Request: r,
			Query:   store.NewQuery(),
		}

		// parse sort parameter
		sortStr := r.FormValue("sorts")
		if sortStr != "" {
			sorts := strings.Split(sortStr, ",")
			for _, sort := range sorts {
				sReq.Query.Sort(sort)
			}
		}

		// parse paging request parameter
		// (values that fail to parse are left as 0)
		offset, limit := func(r *http.Request) (o, l uint64) {
			ostr := r.FormValue("offset")
			lstr := r.FormValue("limit")
			if ostr != "" {
				if ot, err := strconv.ParseUint(ostr, 10, 64); err == nil {
					o = ot
				}
			}
			if lstr != "" {
				if lt, err := strconv.ParseUint(lstr, 10, 64); err == nil {
					l = lt
				}
			}
			return
		}(r)

		// retrieve
		sReq.Query.SetOffset(offset)
		sReq.Query.SetLimit(limit)

		request = sReq
		return
	}

	// decodeJSONReq returns a DecodeRequestFunc that decode request
	// into allocated memory structure
	var decodeJSONReq httptransport.DecodeRequestFunc = func(ctx context.Context, r *http.Request) (request interface{}, err error) {
		return decodeJSONEntity(ctx, r)
	}

	// decodeUpdate returns a DecodeRequestFunc that decode request
	var decodeUpdate httptransport.DecodeRequestFunc = func(ctx context.Context, r *http.Request) (request interface{}, err error) {

		sReq, err := decodeServiceIDReq(ctx, r)
		if err != nil {
			return
		}

		sReq.Payload, err = decodeJSONEntity(ctx, r)
		if err != nil {
			return
		}

		request = sReq
		return
	}

	//
	// ==== httpservice.Services
	//

	// define middleware chains of all RESTful endpoints
	handlers = make(map[string]*httpservice.Service)

	handlers["create"] = httpservice.NewJSONService(
		paths.Plural(), endpoints["create"])
	handlers["create"].Weight = 1
	handlers["create"].Methods = []string{"POST"}
	handlers["create"].DecodeFunc = decodeJSONReq
	handlers["create"].Middlewares.Add(httpservice.MWProtocol, prepareProtocol)
	handlers["create"].Middlewares.Add(httpservice.MWPrepare, prepareCreate)
	handlers["create"].Middlewares.Add(httpservice.MWInner,
		checkPermBefore("create "+noun.Singular()))

	handlers["retrieve"] = httpservice.NewJSONService(
		paths.Singular(), endpoints["retrieve"])
	handlers["retrieve"].Methods = []string{"GET"}
	handlers["retrieve"].DecodeFunc = decodeIDReq
	handlers["retrieve"].Middlewares.Add(httpservice.MWProtocol, prepareProtocol)
	handlers["retrieve"].Middlewares.Add(httpservice.MWPrepare, prepareList)
	handlers["retrieve"].Middlewares.Add(httpservice.MWInner,
		checkPermAfter("retrieve "+noun.Singular()))

	handlers["update"] = httpservice.NewJSONService(
		paths.Singular(), endpoints["update"])
	handlers["update"].Methods = []string{"PUT"}
	handlers["update"].DecodeFunc = decodeUpdate
	handlers["update"].Middlewares.Add(httpservice.MWProtocol, prepareProtocol)
	handlers["update"].Middlewares.Add(httpservice.MWPrepare, prepareUpdate)
	handlers["update"].Middlewares.Add(httpservice.MWInner,
		checkPermBefore("update "+noun.Singular()))

	handlers["list"] = httpservice.NewJSONService(
		paths.Plural(), endpoints["list"])
	handlers["list"].Weight = 1
	handlers["list"].Methods = []string{"GET"}
	handlers["list"].DecodeFunc = decodeListReq
	handlers["list"].Middlewares.Add(httpservice.MWProtocol, prepareProtocol)
	handlers["list"].Middlewares.Add(httpservice.MWPrepare, prepareList)
	handlers["list"].Middlewares.Add(httpservice.MWInner,
		checkPermAfter("list "+noun.Singular()))

	handlers["delete"] = httpservice.NewJSONService(
		paths.Singular(), endpoints["delete"])
	handlers["delete"].Methods = []string{"DELETE"}
	handlers["delete"].DecodeFunc = decodeIDReq
	handlers["delete"].Middlewares.Add(httpservice.MWProtocol, prepareProtocol)
	handlers["delete"].Middlewares.Add(httpservice.MWInner,
		checkPermBefore("delete "+noun.Singular()))

	return
}