Ejemplo n.º 1
0
// TestSetGet checks that a store registered in a store.Factory can be
// obtained through the context with store.Get, and that store.CloseAllIn
// closes the connection produced by the factory's source.
//
// The store provider registered under key always fails with
// fmt.Errorf("%s", sess). The test expects that error text to equal the
// random msg given to tConn. That presumes tConn formats as its msg; confirm
// against tConn's definition.
func TestSetGet(t *testing.T) {

	type tempKey int

	const (
		srcKey tempKey = iota
		key
	)

	// randString returns n letters picked at random from [a-zA-Z].
	randString := func(n int) string {
		var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
		b := make([]rune, n)
		for i := range b {
			b[i] = letterRunes[rand.Intn(len(letterRunes))]
		}
		return string(b)
	}

	msg := randString(20)
	ch := make(chan int)

	// prepare the context: a factory whose only source yields a tConn
	// carrying msg and ch, with a store provider that turns the session
	// into an error
	ctx := func(msg interface{}, ch chan<- int) context.Context {

		dummySrc := func() (conn store.Conn, err error) {
			conn = tConn{msg, ch}
			return
		}

		factory := store.NewFactory()
		factory.SetSource(srcKey, store.SourceFunc(dummySrc))
		factory.Set(key, srcKey, func(sess interface{}) (s store.Store, err error) {
			err = fmt.Errorf("%s", sess)
			return
		})

		return store.WithFactory(context.Background(), factory)

	}(msg, ch)

	// get a store; the provider always errors, and the error text proves
	// the session from dummySrc reached the provider
	if _, err := store.Get(ctx, key); err == nil {
		t.Error("unexpected nil error")
	} else if want, have := msg, err.Error(); want != have {
		t.Errorf("expected %#v, got %#v", want, have)
	}

	// test if store would close before timeout; the test expects tConn to
	// signal on ch when it is closed (see the select below)
	timeout := time.After(time.Second)
	store.CloseAllIn(ctx)

	select {
	case <-timeout:
		t.Error("tConn not closed before timeout")
	case <-ch:
		t.Log("tConn closed")
	}

}
Ejemplo n.º 2
0
// TestStorage_AuthorizeData saves an authorize data record through
// oauth2.Storage, loads it back by its code, and checks that the client and
// the user attached to it survive the round trip.
func TestStorage_AuthorizeData(t *testing.T) {

	// define test db
	getContext := func() context.Context {
		factory := store.NewFactory()
		factory.SetSource(store.DefaultSrc, defaultTestSrc())
		factory.Set(oauth2.KeyAccess, store.DefaultSrc, oauth2.AccessDataStoreProvider)
		factory.Set(oauth2.KeyAuth, store.DefaultSrc, oauth2.AuthorizeDataStoreProvider)
		factory.Set(oauth2.KeyClient, store.DefaultSrc, oauth2.ClientStoreProvider)
		factory.Set(oauth2.KeyUser, store.DefaultSrc, oauth2.UserStoreProvider)
		return store.WithFactory(context.Background(), factory)
	}

	// create dummy Client and user
	ctx := getContext()
	defer store.CloseAllIn(ctx)
	storage := &oauth2.Storage{}
	storage.SetContext(ctx)

	c, u := createStoreDummies(ctx, "password", "http://foobar.com/redirect")
	ad := dummyNewAuth(c, u)
	if err := storage.SaveAuthorize(ad.ToOsin()); err != nil {
		t.Fatalf("SaveAuthorize error: %#v", err.Error())
	}

	// load the osin.AuthorizeData from store; every check below dereferences
	// the result, so stop here if loading failed
	oad, err := storage.LoadAuthorize(ad.Code)
	if err != nil {
		t.Fatalf("error: %#v", err.Error())
	}
	if oad == nil || oad.Client == nil {
		t.Fatalf("expected loaded AuthorizeData with a client, got %#v", oad)
	}

	// Test if loaded Client equals to client in original one
	if want, have := c.GetId(), oad.Client.GetId(); want != have {
		t.Errorf("expected %#v, got %#v", want, have)
	}
	if want, have := c.GetRedirectUri(), oad.Client.GetRedirectUri(); want != have {
		t.Errorf("expected %#v, got %#v", want, have)
	}
	if want, have := c.GetSecret(), oad.Client.GetSecret(); want != have {
		t.Errorf("expected %#v, got %#v", want, have)
	}

	// Test if UserData equals to original one. Use checked assertions so a
	// wrong type fails the test instead of panicking.
	u1, ok := ad.UserData.(*oauth2.User)
	if !ok {
		t.Fatalf("expected ad.UserData to be *oauth2.User, got %#v", ad.UserData)
	}
	u2, ok := oad.UserData.(*oauth2.User)
	if !ok {
		t.Fatalf("expected loaded UserData to be *oauth2.User, got %#v", oad.UserData)
	}

	// Each field is checked independently so every mismatch is reported.
	if want, have := u1.ID, u2.ID; want != have {
		t.Errorf("expected %#v, got %#v", want, have)
	}
	if want, have := u1.Email, u2.Email; want != have {
		t.Errorf("expected %#v, got %#v", want, have)
	}
	if want, have := u1.Name, u2.Name; want != have {
		t.Errorf("expected %#v, got %#v", want, have)
	}
	if want, have := u1.Password, u2.Password; want != have {
		t.Errorf("expected %#v, got %#v", want, have)
	}
	// timestamps are compared at one-second resolution (Unix seconds),
	// presumably because the backing store does not keep sub-second precision
	if want, have := u1.Created, u2.Created; want.Unix() != have.Unix() {
		t.Errorf("expected %#v, got %#v", want, have)
	}
	if want, have := u1.Updated, u2.Updated; want.Unix() != have.Unix() {
		t.Errorf("expected %#v, got %#v", want, have)
	}

}
Ejemplo n.º 3
0
Archivo: manager.go Proyecto: gourd/kit
// GetEndpoints builds the HTTP handlers for the OAuth2 authorize, token and
// information endpoints and returns them as *Endpoints.
//
// Every handler is wrapped by sessionContext. For each request it creates a
// fresh context.Context that carries factory (store.WithFactory). When the
// request ends it calls store.CloseAllIn on that context.
func (m *Manager) GetEndpoints(factory store.Factory) *Endpoints {

	// tryLogin tries to authenticate the user from a POST login form.
	//
	// It reads the "password" form field, obtains the user store registered
	// under KeyUser, and looks the user up with m.userFunc. It then checks the
	// password with PasswordIs. On success it returns the user.
	//
	// On failure it returns:
	//   - a plain error "empty password" when no password was posted;
	//   - store.Error 500 when the user store or the lookup fails;
	//   - store.Error 400 when the user is not found or the password is wrong;
	//   - store.Error 401 "Require login" for non-POST requests, so the caller
	//     shows the login form.
	tryLogin := func(ctx context.Context, r *http.Request) (user OAuth2User, err error) {

		// msg is declared outside this function; presumably it is the
		// package's informational logger (confirm)
		logger := msg
		logger.Log(
			"func", "tryLogin (Manager.GetEndpoints)")

		// parse POST input
		// NOTE(review): the ParseForm error is ignored. On a malformed body
		// the form may be incomplete and the request then fails later
		// (e.g. as "empty password").
		r.ParseForm()
		if r.Method == "POST" {

			var u OAuth2User
			var us store.Store

			// get and check password non-empty
			password := r.Form.Get("password")
			if password == "" {
				err = errors.New("empty password")
				return
			}

			// obtain user store
			us, err = store.Get(ctx, KeyUser)
			if err != nil {
				err = store.Error(
					http.StatusInternalServerError,
					http.StatusText(http.StatusInternalServerError)).
					TellServer("error obtaining user store: %s", err.Error())
				return
			}

			// get user by userFunc; a not-found lookup is reported to the
			// client with the same message as a wrong password
			u, err = m.userFunc(r, us)
			if err != nil {
				serr := store.ExpandError(err)
				if serr.Status == http.StatusNotFound {
					err = store.Error(http.StatusBadRequest, "user id or password incorrect").
						TellServer("user not found")
				} else {
					err = store.Error(
						http.StatusInternalServerError,
						http.StatusText(http.StatusInternalServerError)).
						TellServer("error obtaining user: %s", serr.ServerMsg)
				}
				return
			}

			// if user is nil, user not found
			// NOTE(review): if OAuth2User is an interface, a typed-nil pointer
			// returned by userFunc would not be caught by this check. Confirm
			// that userFunc returns a literal nil.
			if u == nil {
				err = store.Error(http.StatusBadRequest, "user not found")
				return
			}

			// if password does not match
			if !u.PasswordIs(password) {
				err = store.Error(http.StatusBadRequest, "user id or password incorrect").
					TellServer("incorrect password")
				return
			}

			// return pointer of user object, allow it to be re-cast
			logger.Log(
				"func", "tryLogin (Manager.GetEndpoints)",
				"message", "login success")
			user = u
			return
		}

		// no POST input or incorrect login, show form
		// end login handling sequence and wait for
		// user input from login form
		err = store.Error(http.StatusUnauthorized, "Require login").
			TellServer("no POST input")
		return
	}

	// ContextHandlerFunc handles one request using a per-request context.
	// If it returns a non-nil *osin.Response, sessionContext writes that
	// response as JSON. If it returns nil, the handler has already written its
	// own response (the auth endpoint does this for the login form).
	type ContextHandlerFunc func(ctx context.Context,
		w http.ResponseWriter, r *http.Request) *osin.Response

	// sessionContext takes a ContextHandlerFunc and returns
	// a http.HandlerFunc.
	// The returned handler:
	//   - builds a fresh context with factory for every request;
	//   - closes all stores in that context when the request ends;
	//   - logs resp.InternalError, if any;
	//   - writes a non-nil response as JSON.
	sessionContext := func(inner ContextHandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			// per connection based context.Context, with factory
			ctx := store.WithFactory(context.Background(), factory)
			defer store.CloseAllIn(ctx)
			if resp := inner(ctx, w, r); resp != nil {
				if resp.InternalError != nil {
					// errMsg is declared outside this function; presumably
					// it is the package's error logger (confirm)
					errLogger := errMsg
					errLogger.Log(
						"func", "sessionContext (Manager.GetEndpoints)",
						"error", resp.InternalError.Error())
				}
				osin.OutputJSON(resp, w, r)
			}
		}
	}

	ep := Endpoints{}

	// authorize endpoint: asks the user to log in (via tryLogin / login
	// form), then marks the osin authorize request as authorized
	ep.Auth = sessionContext(func(ctx context.Context,
		w http.ResponseWriter, r *http.Request) *osin.Response {

		logger := msg
		logger.Log(
			"endpoint", "auth")

		srvr := m.osinServer
		resp := srvr.NewResponse()
		// Bind the response's storage to this request's context, presumably
		// so storage lookups use the per-request stores. The type assertion
		// panics if the osin server's storage is not *Storage.
		resp.Storage.(*Storage).SetContext(ctx)

		// handle authorize request with osin
		if ar := srvr.HandleAuthorizeRequest(resp, r); ar != nil {
			logger.Log(
				"endpoint", "auth",
				"message", "handle authorize request")

			// TODO: maybe redirect to another URL for
			//       dedicated login form flow?
			var err error
			if ar.UserData, err = tryLogin(ctx, r); err != nil {
				serr := store.ExpandError(err)
				logger.Log(
					"endpoint", "auth",
					"message", "handle authorize request",
					"error", serr.ServerMsg)

				// Not logged in (yet): render the login form ourselves and
				// return nil so sessionContext writes nothing more.
				lctx := &LoginFormContext{
					Context:        withOsinAuthRequest(ctx, ar),
					LoginErr:       err,
					ResponseWriter: w,
					Request:        r,
					Logger:         logger,
				}
				m.showLoginForm(lctx, w, r)
				return nil
			}

			logger.Log(
				"endpoint", "auth",
				"message", "User obtained",
				"osin.AuthorizeData.UserData", fmt.Sprintf("%#v", ar.UserData))

			// login succeeded: authorize the request and let osin finish it
			// into resp
			ar.Authorized = true
			srvr.FinishAuthorizeRequest(resp, r, ar)
		}

		// NOTE(review): this line also runs when HandleAuthorizeRequest
		// returned nil, so the "User obtained" message can be misleading.
		logger.Log(
			"endpoint", "auth",
			"message", "User obtained",
			"response", fmt.Sprintf("%#v", resp))

		// NOTE(review): unlike the information endpoint, resp.Close() is not
		// called here. Confirm whether the response storage needs closing.
		return resp
	})

	// token endpoint: every access request that osin accepts is
	// authorized unconditionally (see TODO below)
	ep.Token = sessionContext(func(ctx context.Context,
		w http.ResponseWriter, r *http.Request) *osin.Response {

		logger := msg
		logger.Log(
			"endpoint", "token")

		srvr := m.osinServer
		resp := srvr.NewResponse()
		resp.Storage.(*Storage).SetContext(ctx)

		if ar := srvr.HandleAccessRequest(resp, r); ar != nil {
			// TODO: handle authorization
			// check if the user has the permission to grant the scope
			logger.Log(
				"endpoint", "token",
				"message", "access successful")
			ar.Authorized = true
			srvr.FinishAccessRequest(resp, r, ar)
		}

		logger.Log(
			"endpoint", "token",
			"response", fmt.Sprintf("%#v", resp))
		// NOTE(review): resp.Close() is not called here either; see the
		// note on the authorize endpoint.
		return resp
	})

	// information endpoint
	ep.Info = sessionContext(func(ctx context.Context,
		w http.ResponseWriter, r *http.Request) *osin.Response {

		logger := msg
		logger.Log(
			"endpoint", "information")

		srvr := m.osinServer

		resp := srvr.NewResponse()
		resp.Storage.(*Storage).SetContext(ctx)
		// runs after this function returns, i.e. before sessionContext
		// writes resp as JSON
		defer resp.Close()

		if ir := srvr.HandleInfoRequest(resp, r); ir != nil {
			srvr.FinishInfoRequest(resp, r, ir)
		}

		logger.Log(
			"endpoint", "information",
			"response", fmt.Sprintf("%#v", resp))
		return resp
	})

	return &ep

}
Ejemplo n.º 4
0
// TestOAuth2HTTP tests the whole stack with real HTTP calls against handlers
// wrapped in an httptest.Server.
//
// It walks the authorization-code flow:
//  1. obtain a code;
//  2. exchange it for an access/refresh token pair;
//  3. refresh the token;
//  4. fetch a protected content path with the new access token and compare
//     the body against the expected message.
func TestOAuth2HTTP(t *testing.T) {

	var err error

	// a dummy password for dummy user
	password := "******"
	message := "Success"

	// test store context
	type tempKey int
	const (
		testDB tempKey = iota
	)
	factory := store.NewFactory()
	factory.SetSource(testDB, defaultTestSrc())
	factory.Set(oauth2.KeyUser, testDB, oauth2.UserStoreProvider)
	factory.Set(oauth2.KeyClient, testDB, oauth2.ClientStoreProvider)
	factory.Set(oauth2.KeyAccess, testDB, oauth2.AccessDataStoreProvider)
	factory.Set(oauth2.KeyAuth, testDB, oauth2.AuthorizeDataStoreProvider)
	ctx := store.WithFactory(context.Background(), factory)

	testCtx := &testContext{
		password:     password,
		t:            t,
		redirectBase: "https://test.foobar/example_app/",
		redirectURL:  "https://test.foobar/example_app/code",
		oauth2Path:   "/oauth2",
	}

	// create test oauth2 server; deferred Close also runs on t.Fatal
	ts := httptest.NewServer(testOAuth2Server(t, testCtx.oauth2Path, message))
	defer ts.Close()
	testCtx.oauth2Base = ts.URL

	t.Logf("auth endpoint %#v", testCtx.AuthEndpoint())

	// create dummy client and user, then close the stores opened in ctx;
	// the HTTP handlers use their own per-request contexts
	testCtx.client, testCtx.user = createStoreDummies(ctx, testCtx.password, testCtx.redirectBase)
	store.CloseAllIn(ctx)

	// request an authorization code
	testCtx.code, err = getCodeHTTP(t, getCodeRequest(testCtx))
	if err != nil {
		t.Fatalf("getCodeHTTP: %v", err)
	}

	// retrieve token from token endpoint
	// get response from client web app redirect uri
	testCtx.token, testCtx.refresh, err = getTokenHTTP(t, getTokenRequest(testCtx))
	if err != nil {
		t.Fatalf("getTokenHTTP (token request): %v", err)
	}

	// try to refresh token
	t.Logf(`refresh_token=%s token=%s msg="refresh token test"`, testCtx.refresh, testCtx.token)
	testCtx.token, testCtx.refresh, err = getTokenHTTP(t, getRefreshRequest(testCtx))
	if err != nil {
		t.Fatalf("getTokenHTTP (refresh request): %v", err)
	}
	t.Logf(`refresh_token=%s token=%s msg="refresh token test success"`, testCtx.refresh, testCtx.token)

	// retrieve a testing content path
	body, err := getContentHTTP(t, getContentRequest(testCtx.token, ts.URL+"/content"))
	if err != nil {
		t.Fatalf("getContentHTTP: %v", err)
	}

	// final result
	if want, have := message, body; want != have {
		t.Errorf("expected: %#v, got: %#v", want, have)
	}
	t.Logf("result: %#v", body)

}
Ejemplo n.º 5
0
// TestGetAccess_Session drives the OAuth2 router in-process, without an HTTP
// server. It creates a dummy client and user, obtains an authorization code
// and then a token. Finally it checks that oauth2.UseToken followed by
// oauth2.LoadTokenAccess puts the matching AccessData into the context.
func TestGetAccess_Session(t *testing.T) {

	var err error

	// test oauth2 server (router only)
	testCtx := &testContext{
		password:     "******",
		t:            t,
		redirectBase: "https://test.foobar/example_app/",
		redirectURL:  "https://test.foobar/example_app/code",
		oauth2Path:   "/oauth2/dummy",
	}

	message := "Success"
	oauth2Srvr := testOAuth2Server(t, testCtx.oauth2Path, message)
	contentURL := "/content"

	// test store context
	type tempKey int
	const (
		testDB tempKey = iota
	)
	factory := store.NewFactory()
	factory.SetSource(testDB, defaultTestSrc())
	factory.Set(oauth2.KeyUser, testDB, oauth2.UserStoreProvider)
	factory.Set(oauth2.KeyClient, testDB, oauth2.ClientStoreProvider)
	factory.Set(oauth2.KeyAccess, testDB, oauth2.AccessDataStoreProvider)
	factory.Set(oauth2.KeyAuth, testDB, oauth2.AuthorizeDataStoreProvider)
	ctx := store.WithFactory(context.Background(), factory)
	// the deferred call captures this ctx value, not the one reassigned below
	defer store.CloseAllIn(ctx)

	// create dummy oauth client and user
	testCtx.client, testCtx.user = createStoreDummies(ctx,
		testCtx.password, testCtx.redirectBase)

	// run the code request
	testCtx.code, err = getCode(oauth2Srvr, getCodeRequest(testCtx))
	if err != nil {
		t.Fatalf("getCode error (%#v)", err.Error())
	}
	t.Logf("code:  %#v", testCtx.code)

	// get oauth2 token
	testCtx.token, err = getToken(oauth2Srvr, getTokenRequest(testCtx))
	if err != nil {
		t.Fatalf("getToken error (%s)", err.Error())
	}
	t.Logf("token: %#v", testCtx.token)

	// a context that has not been through the token middleware holds no
	// access data
	if want, have := (*oauth2.AccessData)(nil), oauth2.GetAccess(ctx); want != have {
		t.Errorf("expected %#v, got %#v", want, have)
	}

	// middleware routine: WithAccess set context with proper token passed
	// test getting AccessData from supposed context with AccessData
	r := getContentRequest(testCtx.token, contentURL)
	ctx = oauth2.LoadTokenAccess(oauth2.UseToken(ctx, r))
	access := oauth2.GetAccess(ctx)
	if access == nil {
		t.Fatalf("expected *AccessData, got %#v", access)
	}

	// NOTE(review): the loaded AccessData is expected to have an empty ID;
	// presumably the ID is not populated when loading by token. Confirm.
	if want, have := "", access.ID; want != have {
		t.Errorf("expect %#v, got %#v", want, have)
	}
	if access.ClientID == "" {
		t.Errorf("access.ClientId expected to be not empty")
	}
	if want, have := testCtx.token, access.AccessToken; want != have {
		t.Errorf("expect %#v, got %#v", want, have)
	}
	if want, have := testCtx.user.ID, access.UserID; want != have {
		t.Errorf("expect %#v, got %#v", want, have)
	}
	// checked assertion: a UserData of another type fails the test instead
	// of panicking
	if access.UserData == nil {
		t.Error("expect access.UserData not nil")
	} else if user, ok := access.UserData.(*oauth2.User); !ok {
		t.Errorf("expect access.UserData to be *oauth2.User, got %#v", access.UserData)
	} else if want, have := testCtx.user.ID, user.ID; want != have {
		t.Errorf("expect %#v, got %#v", want, have)
	}
	if access.RefreshToken == "" {
		t.Errorf("access.RefreshToken expected to be not empty")
	}

}
Ejemplo n.º 6
0
// TestStorage_AccessData saves access data through oauth2.Storage, loads it
// back by access token, and checks that the loaded data matches the original.
// It does this for two records: a first access record, and a second one
// chained to the first (dummyNewAccess with access1 as its last argument).
// For the chained record it also checks that the previous AccessData is
// preserved.
func TestStorage_AccessData(t *testing.T) {

	// authEqual reports the first difference in Code or ExpiresIn between
	// two osin.AuthorizeData, or an error if either one is nil.
	authEqual := func(a1, a2 *osin.AuthorizeData) (err error) {
		if a1 == nil {
			err = fmt.Errorf("unexpected nil a1")
			return
		}
		if a2 == nil {
			err = fmt.Errorf("unexpected nil a2")
			return
		}

		if v1, v2 := a1.Code, a2.Code; v1 != v2 {
			err = fmt.Errorf("Code not equal. %#v != %#v", v1, v2)
			return
		}
		if v1, v2 := a1.ExpiresIn, a2.ExpiresIn; v1 != v2 {
			err = fmt.Errorf("ExpiresIn not equal. %#v != %#v", v1, v2)
			return
		}
		return
	}

	// accessMatch reports the first field that differs between an
	// *oauth2.AccessData and an *osin.AccessData, or an error if either is
	// nil. CreatedAt is compared at one-second resolution (Unix seconds).
	accessMatch := func(access *oauth2.AccessData, oaccess *osin.AccessData) (err error) {
		if access == nil {
			err = fmt.Errorf("unexpected nil *oauth2.AccessData")
			return
		}
		if oaccess == nil {
			err = fmt.Errorf("unexpected nil *osin.AccessData")
			return
		}

		if v1, v2 := access.AccessToken, oaccess.AccessToken; v1 != v2 {
			err = fmt.Errorf("AccessToken mismatch.\n*oauth2.AccessData=%#v, *osin.AccessData=%#v",
				v1, v2)
			return
		}
		if v1, v2 := access.RefreshToken, oaccess.RefreshToken; v1 != v2 {
			err = fmt.Errorf("RefreshToken mismatch.\n*oauth2.RefreshToken=%#v, *osin.RefreshToken=%#v",
				v1, v2)
			return
		}
		if v1, v2 := access.ExpiresIn, oaccess.ExpiresIn; v1 != v2 {
			err = fmt.Errorf("ExpiresIn mismatch.\n*oauth2.ExpiresIn=%#v, *osin.ExpiresIn=%#v",
				v1, v2)
			return
		}
		if v1, v2 := access.Scope, oaccess.Scope; v1 != v2 {
			err = fmt.Errorf("Scope mismatch.\n*oauth2.Scope=%#v, *osin.Scope=%#v",
				v1, v2)
			return
		}
		if v1, v2 := access.RedirectURI, oaccess.RedirectUri; v1 != v2 {
			err = fmt.Errorf("RedirectUri mismatch.\n*oauth2.RedirectUri=%#v, *osin.RedirectUri=%#v",
				v1, v2)
			return
		}
		if v1, v2 := access.CreatedAt, oaccess.CreatedAt; v1.Unix() != v2.Unix() {
			err = fmt.Errorf("CreatedAt mismatch.\n*oauth2.CreatedAt=%#v, *osin.CreatedAt=%#v",
				v1, v2)
			return
		}
		return
	}

	// define test db
	getContext := func() context.Context {
		factory := store.NewFactory()
		factory.SetSource(store.DefaultSrc, defaultTestSrc())
		factory.Set(oauth2.KeyAccess, store.DefaultSrc, oauth2.AccessDataStoreProvider)
		factory.Set(oauth2.KeyAuth, store.DefaultSrc, oauth2.AuthorizeDataStoreProvider)
		factory.Set(oauth2.KeyClient, store.DefaultSrc, oauth2.ClientStoreProvider)
		factory.Set(oauth2.KeyUser, store.DefaultSrc, oauth2.UserStoreProvider)
		return store.WithFactory(context.Background(), factory)
	}

	// create dummy Client and user
	ctx := getContext()
	defer store.CloseAllIn(ctx)
	storage := &oauth2.Storage{}
	storage.SetContext(ctx)

	c, u := createStoreDummies(ctx, "password", "http://foobar.com/redirect")
	ad := dummyNewAuth(c, u)
	access1 := dummyNewAccess(c, u, ad, nil)

	// save and reload the first access record; stop on failure because the
	// checks below dereference the loaded value
	if err := storage.SaveAccess(access1.ToOsin()); err != nil {
		t.Fatalf("unexpected error %#v", err.Error())
	}
	oaccess1, err := storage.LoadAccess(access1.AccessToken)
	if err != nil {
		t.Fatalf("unexpected error %#v", err.Error())
	}
	if oaccess1 == nil {
		t.Fatal("unexpected nil *osin.AccessData for access1")
	}

	if err := accessMatch(access1, oaccess1); err != nil {
		t.Errorf("access1 != oaccess1, err = %#v", err.Error())
	}
	if err := authEqual(access1.AuthorizeData.ToOsin(), oaccess1.AuthorizeData); err != nil {
		t.Errorf("access1.AuthorizeData != oaccess1.AuthorizeData, err = %#v", err.Error())
		t.Logf("\naccess1.AuthorizeData=%#v\noaccess1.AuthorizeData=%#v",
			access1.AuthorizeData, oaccess1.AuthorizeData)
	}

	// second access record, chained to the first
	access2 := dummyNewAccess(c, u, ad, access1)
	if access2.AccessData == nil {
		t.Error("unexpected nil value")
	} else if access2.ToOsin().AccessData == nil {
		t.Error("unexpected nil value")
	}

	if err := storage.SaveAccess(access2.ToOsin()); err != nil {
		t.Fatalf("unexpected error %#v", err.Error())
	}
	oaccess2, err := storage.LoadAccess(access2.AccessToken)
	if err != nil {
		t.Fatalf("unexpected error %#v", err.Error())
	}
	if oaccess2 == nil {
		t.Fatal("unexpected nil *osin.AccessData for access2")
	}

	if err := accessMatch(access2, oaccess2); err != nil {
		t.Errorf("access2 != oaccess2, err = %#v", err.Error())
	}
	if err := authEqual(access2.AuthorizeData.ToOsin(), oaccess2.AuthorizeData); err != nil {
		t.Errorf("access2.AuthorizeData != oaccess2.AuthorizeData, err = %#v", err.Error())
		t.Logf("\naccess2.AuthorizeData=%#v\noaccess2.AuthorizeData=%#v",
			access2.AuthorizeData, oaccess2.AuthorizeData)
	}
	// the loaded second record must still point at the first access data
	if err := accessMatch(access1, oaccess2.AccessData); err != nil {
		t.Errorf("access1 != oaccess2.AccessData, err = %#v", err.Error())
	}

}