예제 #1
0
// TestVerifySignedJwt feeds verifySignedJwt a table of token strings and
// verification instants. A row with a nil expected value must fail to verify;
// a row with a non-nil expected value must decode to exactly that object.
func TestVerifySignedJwt(t *testing.T) {
	defer stubMemcacheGetCerts()()
	req, cleanup := tu.NewTestRequest("GET", "/", nil)
	defer cleanup()

	day := time.Hour * 24
	cases := []struct {
		token    string
		now      time.Time
		expected *signedJWT
	}{
		{jwtValidTokenString, jwtValidTokenTime, &jwtValidTokenObject},
		// The valid token checked one day later and one day earlier; both
		// rows expect verification to fail.
		{jwtValidTokenString, jwtValidTokenTime.Add(day), nil},
		{jwtValidTokenString, jwtValidTokenTime.Add(-day), nil},
		{jwtInvalidKeyToken, jwtValidTokenTime, nil},
		{jwtInvalidAlgToken, jwtValidTokenTime, nil},
		{"invalid.token", jwtValidTokenTime, nil},
		{"another.invalid.token", jwtValidTokenTime, nil},
	}

	ctx := NewContext(req)

	for n, tc := range cases {
		got, err := verifySignedJwt(ctx, tc.token, tc.now.Unix())
		wantErr := tc.expected == nil
		if err != nil {
			if !wantErr {
				t.Errorf("%d: didn't expect error: %v", n, err)
			}
			continue
		}
		if wantErr {
			t.Errorf("%d: expected error, got: %#v", n, got)
			continue
		}
		assertEquals(t, n, got, tc.expected)
	}
}
예제 #2
0
// TestTokeninfoContextCurrentOAuthClientID drives the tokeninfo-backed context
// through a table of canned urlfetch responses. Rows with an empty clientId
// must produce an error; every other row must yield exactly that client ID.
func TestTokeninfoContextCurrentOAuthClientID(t *testing.T) {
	const token = "some_token"

	type test struct {
		token, scope, clientId string
		httpStatus             int32
		content                []byte
		fetchErr               error
	}

	// active is the row currently under test; the urlfetch stub serves its
	// status code, body and error.
	var active *test

	fetchStub := func(in, out proto.Message, _ *tu.RpcCallOptions) error {
		wantURL := tokeninfoEndpointUrl + "?access_token=" + token
		if gotURL := in.(*fetch_pb.URLFetchRequest).GetUrl(); gotURL != wantURL {
			t.Errorf("fetch: expected URL %q, got %q", wantURL, gotURL)
		}
		resp := out.(*fetch_pb.URLFetchResponse)
		resp.StatusCode = proto.Int32(active.httpStatus)
		resp.Content = active.content
		return active.fetchErr
	}
	defer tu.RegisterAPIOverride("urlfetch", "Fetch", fetchStub)()

	httpReq, cleanup := tu.NewTestRequest("GET", "/", nil)
	defer cleanup()

	rows := []*test{
		// token, scope, clientId, httpStatus, content, fetchErr
		{token, "scope.one", "my-client-id", 200, tokeninfoValid, nil},
		{token, "scope.two", "my-client-id", 200, tokeninfoValid, nil},
		{token, "scope.one", "", 200, tokeninfoUnverified, nil},
		{token, "scope.one", "", 200, tokeninfoInvalidEmail, nil},
		{token, "scope.one", "", 401, tokeninfoError, nil},
		{token, "invalid.scope", "", 200, tokeninfoValid, nil},
		{token, "scope.one", "", 400, []byte("{}"), nil},
		{token, "scope.one", "", 200, []byte(""), nil},
		{token, "scope.one", "", -1, nil, errors.New("Fake urlfetch error")},
		{"", "scope.one", "", 200, tokeninfoValid, nil},
	}

	c := tokeninfoContextFactory(httpReq)
	for n, row := range rows {
		active = row
		httpReq.Header.Set("authorization", "bearer "+row.token)
		id, err := c.CurrentOAuthClientID(row.scope)
		switch {
		case err != nil:
			if row.clientId != "" {
				t.Errorf("%d: expected %q, got error %v", n, row.clientId, err)
			}
		case row.clientId == "":
			t.Errorf("%d: expected error, got %q", n, id)
		case id != row.clientId:
			t.Errorf("%d: expected %q, got %q", n, row.clientId, id)
		}
	}
}
예제 #3
0
// TestCurrentIDTokenUser replaces jwtParser with a stub that returns a
// preselected token (or an error when none is selected), then checks which of
// those tokens currentIDTokenUser accepts for the given audiences and
// authorized parties. Rows with an empty expectedEmail must produce an error.
func TestCurrentIDTokenUser(t *testing.T) {
	// Put the real parser back once the test is done.
	savedParser := jwtParser
	defer func() {
		jwtParser = savedParser
	}()
	req, cleanup := tu.NewTestRequest("GET", "/", nil)
	defer cleanup()
	ctx := NewContext(req)

	aud := []string{jwtValidTokenObject.Audience, jwtValidTokenObject.ClientID}
	azp := []string{jwtValidTokenObject.ClientID}

	// A well-formed token whose audience and client ID are not in aud/azp.
	unaccepted := signedJWT{
		Audience: "my-other-client-id",
		ClientID: "my-other-client-id",
		Email:    "*****@*****.**",
		Expires:  1370352252,
		IssuedAt: 1370348652,
		Issuer:   "accounts.google.com",
	}

	cases := []struct {
		token         *signedJWT
		expectedEmail string
	}{
		{&jwtValidTokenObject, jwtValidTokenObject.Email},
		{&unaccepted, ""},
		{nil, ""},
	}

	// selected is what the stubbed parser hands back for the current row.
	var selected *signedJWT

	jwtParser = func(Context, string, int64) (*signedJWT, error) {
		if selected != nil {
			return selected, nil
		}
		return nil, errors.New("Fake verification failed")
	}

	for n, tc := range cases {
		selected = tc.token
		user, err := currentIDTokenUser(ctx,
			jwtValidTokenString, aud, azp, jwtValidTokenTime.Unix())
		wantUser := tc.expectedEmail != ""
		if err != nil {
			if wantUser {
				t.Errorf("%d: unexpected error: %v", n, err)
			}
			continue
		}
		if !wantUser {
			t.Errorf("%d: expected error, got: %#v", n, user)
			continue
		}
		if user.Email != tc.expectedEmail {
			t.Errorf("%d: expected %q, got %q", n, tc.expectedEmail, user.Email)
		}
	}
}
예제 #4
0
// TestGetCachedCertsCacheHit stubs memcache Get so every lookup is a hit that
// returns cacheValue, then checks how getCachedCerts decodes each cached
// payload. Rows with a nil expected value must produce an error.
func TestGetCachedCertsCacheHit(t *testing.T) {
	// cacheValue is the raw item value the memcache stub returns; it is
	// reassigned for every table row below.
	var cacheValue []byte
	mcGetStub := func(in, out proto.Message, _ *tu.RpcCallOptions) error {
		req := in.(*mc_pb.MemcacheGetRequest)
		if req.GetNameSpace() != certNamespace {
			// Expected namespace first, then the one actually requested.
			t.Errorf("memcache: expected %q ns, got %q",
				certNamespace, req.GetNameSpace())
		}

		item := &mc_pb.MemcacheGetResponse_Item{
			Key:   req.Key[0],
			Value: cacheValue,
		}
		resp := out.(*mc_pb.MemcacheGetResponse)
		resp.Item = []*mc_pb.MemcacheGetResponse_Item{item}
		return nil
	}
	defer tu.RegisterAPIOverride("memcache", "Get", mcGetStub)()
	r, deleteAppengineContext := tu.NewTestRequest("GET", "/", nil)
	defer deleteAppengineContext()

	tts := []struct {
		cacheValue string
		expected   *certsList
	}{
		{"", nil},
		{"{}", &certsList{}},
		{`{"keyvalues": [{}]}`, &certsList{[]*certInfo{{}}}},
		{`{"keyvalues": [
	    	{"algorithm": "RS256",
	    	 "exponent": "123",
	    	 "keyid": "some-id",
	    	 "modulus": "123"} ]}`,
			&certsList{[]*certInfo{{"RS256", "123", "some-id", "123"}}}},
	}
	for i, tt := range tts {
		cacheValue = []byte(tt.cacheValue)
		out, err := getCachedCerts(NewContext(r))
		switch {
		case err != nil && tt.expected != nil:
			t.Errorf("%d: didn't expect error %v", i, err)
		case err == nil && tt.expected == nil:
			t.Errorf("%d: expected error, got %#v", i, out)
		case err == nil && tt.expected != nil:
			assertEquals(t, i, out, tt.expected)
		}
	}
}
예제 #5
0
// TestGetOk requests an item by ID and expects a 200 response whose body is
// exactly that ID.
func TestGetOk(t *testing.T) {
	const itemId = "valid-id"

	req, cleanup := tu.NewTestRequest("GET", "/"+itemId, nil)
	defer cleanup()
	rec := httptest.NewRecorder()

	get(rec, req)

	if got := rec.Code; got != http.StatusOK {
		t.Errorf("Expected 200, got %d", got)
	}
	if got := rec.Body.String(); got != itemId {
		t.Errorf("Expected %q, got %q", itemId, got)
	}
}
예제 #6
0
파일: items_test.go 프로젝트: ritoon/aegot
// TestPutItem checks the datastore Put request that Item.put issues. The
// datastore_v3 Put RPC is stubbed to check that the request carries one
// entity keyed by the item ID with a single "Name" string property, and to
// echo the entity key back as the response.
func TestPutItem(t *testing.T) {
	const (
		itemId   = "some-id"
		itemName = "test"
	)

	putStub := func(in, out proto.Message, _ *tu.RpcCallOptions) error {
		req := in.(*pb.PutRequest)

		entities := req.GetEntity()
		if len(entities) != 1 {
			t.Errorf("Expected 1 entity, got %d", len(entities))
		}
		if len(entities) == 0 {
			// Nothing to inspect, and indexing would panic inside the stub;
			// the failure has already been recorded above.
			return nil
		}
		ent := entities[0]

		// Guard the index so a malformed key reports a mismatch instead of
		// panicking.
		var id string
		if elems := ent.GetKey().GetPath().GetElement(); len(elems) > 0 {
			id = elems[0].GetName()
		}
		if id != itemId {
			t.Errorf("Expected ID %q, got %q", itemId, id)
		}

		props := ent.GetProperty()
		if len(props) != 1 {
			t.Errorf("Expected 1 property, got: %d", len(props))
		}
		if len(props) > 0 {
			prop := props[0]
			if prop.GetName() != "Name" {
				t.Errorf("Invalid property name: %q", prop.GetName())
			}
			if val := prop.GetValue().GetStringValue(); val != itemName {
				t.Errorf("Expected %q, got %q", itemName, val)
			}
		}

		resp := out.(*pb.PutResponse)
		resp.Key = []*pb.Reference{ent.GetKey()}
		return nil
	}
	unregister := tu.RegisterAPIOverride("datastore_v3", "Put", putStub)
	defer unregister()

	r, deleteContext := tu.NewTestRequest("PUT", "/"+itemId, nil)
	defer deleteContext()

	item := Item{Id: itemId, Name: itemName}
	// appengine.NewContext() will use the one created in NewTestRequest() above
	if err := item.put(appengine.NewContext(r)); err != nil {
		t.Error(err)
	}
}
예제 #7
0
// TestGetErrors checks that get responds with the expected HTTP status code
// for each listed path.
func TestGetErrors(t *testing.T) {
	tests := []struct {
		path string
		code int
	}{
		{"/does-not-exist", 404},
		{"/error", 500},
	}
	for _, tc := range tests {
		// Each case runs in its own function so the deferred context cleanup
		// fires at the end of the case, not only when the whole test returns.
		func() {
			r, deleteContext := tu.NewTestRequest("GET", tc.path, nil)
			defer deleteContext()
			w := httptest.NewRecorder()

			get(w, r)

			if w.Code != tc.code {
				t.Errorf("GET %s: expected %d, got %d", tc.path, tc.code, w.Code)
			}
		}()
	}
}
예제 #8
0
// TestVerifyParsedToken runs verifyParsedToken over hand-built signedJWT
// values. The first two rows are expected to be valid; each remaining row
// alters one of issuer, audience, client ID or email and must be rejected.
func TestVerifyParsedToken(t *testing.T) {
	const (
		goog     = "accounts.google.com"
		clientId = "my-client-id"
		email    = "*****@*****.**"
	)
	var (
		audiences = []string{clientId, "hello-android"}
		clientIds = []string{clientId}
	)

	cases := []struct {
		issuer, audience, clientID, email string
		valid                             bool
	}{
		{goog, clientId, clientId, email, true},
		{goog, "hello-android", clientId, email, true},
		{goog, "invalid", clientId, email, false},
		{goog, clientId, "invalid", email, false},
		{goog, clientId, clientId, "", false},
		{"", clientId, clientId, email, false},
	}

	req, cleanup := tu.NewTestRequest("GET", "/", nil)
	defer cleanup()

	ctx := NewContext(req)

	for n, tc := range cases {
		token := signedJWT{
			Issuer:   tc.issuer,
			Audience: tc.audience,
			ClientID: tc.clientID,
			Email:    tc.email,
		}
		got := verifyParsedToken(ctx, token, audiences, clientIds)
		if got == tc.valid {
			continue
		}
		t.Errorf("%d: expected token to be valid? %v, got: %v",
			n, tc.valid, got)
	}
}
예제 #9
0
// TestTokeninfoCurrentOAuthUser stubs urlfetch to return a 200 response with
// the tokeninfoValid body and checks that CurrentOAuthUser resolves to the
// matching email.
func TestTokeninfoCurrentOAuthUser(t *testing.T) {
	fetchStub := func(in, out proto.Message, _ *tu.RpcCallOptions) error {
		resp := out.(*fetch_pb.URLFetchResponse)
		resp.StatusCode = proto.Int32(200)
		resp.Content = tokeninfoValid
		return nil
	}
	defer tu.RegisterAPIOverride("urlfetch", "Fetch", fetchStub)()

	r, deleteCtx := tu.NewTestRequest("GET", "/", nil)
	defer deleteCtx()
	r.Header.Set("authorization", "bearer some_token")

	c := tokeninfoContextFactory(r)
	user, err := c.CurrentOAuthUser("scope.one")
	if err != nil {
		// Stop here: user is not meaningful when an error is returned, and
		// dereferencing it could panic.
		t.Fatalf("unexpected error: %v", err)
	}
	if user.Email != tokeinfoEmail {
		t.Errorf("expected email %q, got %q", tokeinfoEmail, user.Email)
	}
}
예제 #10
0
// TestCurrentUser exercises CurrentUser with both a JWT ID token and OAuth
// bearer tokens.
//
// The GetOAuthUser RPC is stubbed to accept only validScope and EmailScope.
// The JWT path is set up to pass by stubbing the memcache cert lookup and
// pinning currentUTC to jwtValidTokenTime. Rows with an empty expectedEmail
// must produce an error.
func TestCurrentUser(t *testing.T) {
	const (
		clientID    = "my-client-id"
		bearerEmail = "*****@*****.**"
		validScope  = "valid.scope"
	)

	getOAuthRPC := func(in, out proto.Message, _ *tu.RpcCallOptions) error {
		scope := in.(*user_pb.GetOAuthUserRequest).GetScope()
		if scope != validScope && scope != EmailScope {
			return fmt.Errorf("Invalid scope: %q", scope)
		}
		resp := out.(*user_pb.GetOAuthUserResponse)
		resp.ClientId = proto.String(clientID)
		resp.Email = proto.String(bearerEmail)
		resp.AuthDomain = proto.String("example.org")
		resp.UserId = proto.String("12345")
		return nil
	}
	defer tu.RegisterAPIOverride("user", "GetOAuthUser", getOAuthRPC)()

	// stubs to make fake JWT token validations pass
	defer stubMemcacheGetCerts()() // in jwt_test.go
	origCurrentUTC := currentUTC
	defer func() {
		currentUTC = origCurrentUTC
	}()
	currentUTC = func() time.Time {
		return jwtValidTokenTime
	}

	jwtStr, jwt := jwtValidTokenString, jwtValidTokenObject
	tts := []struct {
		token                        string
		scopes, audiences, clientIDs []string
		expectedEmail                string
	}{
		// success
		{jwtStr, []string{EmailScope}, []string{jwt.Audience}, []string{jwt.ClientID}, jwt.Email},
		{"ya29.token", []string{EmailScope}, []string{clientID}, []string{clientID}, bearerEmail},
		{"ya29.token", []string{EmailScope, validScope}, []string{clientID}, []string{clientID}, bearerEmail},
		{"1/token", []string{validScope}, []string{clientID}, []string{clientID}, bearerEmail},

		// failure
		{jwtStr, []string{EmailScope}, []string{"other-client"}, []string{"other-client"}, ""},
		{"some.invalid.jwt", []string{EmailScope}, []string{jwt.Audience}, []string{jwt.ClientID}, ""},
		{"", []string{validScope}, []string{clientID}, []string{clientID}, ""},
		{"ya29.invalid", []string{"invalid.scope"}, []string{clientID}, []string{clientID}, ""},

		{"doesn't matter", nil, []string{clientID}, []string{clientID}, ""},
		{"doesn't matter", []string{EmailScope}, nil, []string{clientID}, ""},
		{"doesn't matter", []string{EmailScope}, []string{clientID}, nil, ""},
	}

	for i, tt := range tts {
		// Each row runs in its own function so its test context is released
		// when the row finishes, not when the whole test returns.
		func() {
			req, deleteCtx := tu.NewTestRequest("GET", "/", nil)
			defer deleteCtx()
			c := cachingContextFactory(req)
			if tt.token != "" {
				req.Header.Set("authorization", "oauth "+tt.token)
			}

			user, err := CurrentUser(c, tt.scopes, tt.audiences, tt.clientIDs)
			switch {
			case tt.expectedEmail == "" && err == nil:
				t.Errorf("%d: expected error, got %#v", i, user)
			case tt.expectedEmail != "" && user == nil:
				t.Errorf("%d: expected user object, got nil (%v)", i, err)
			case tt.expectedEmail != "" && tt.expectedEmail != user.Email:
				t.Errorf("%d: expected %q, got %q", i, tt.expectedEmail, user.Email)
			}
		}()
	}
}
예제 #11
0
// TestCurrentBearerTokenUser stubs the GetOAuthUser RPC to accept only
// validScope and to report validClientID.
//
// It first checks which scope/client-ID combinations CurrentBearerTokenUser
// accepts. It then checks that a successful call copies the RPC's user ID,
// email, auth domain and admin flag into the returned user.
func TestCurrentBearerTokenUser(t *testing.T) {
	var (
		validScope    = "valid.scope"
		validClientID = "my-client-id"

		email      = "*****@*****.**"
		userID     = "12345"
		authDomain = "gmail.com"
		isAdmin    = false

		empty = []string{}
	)

	getOAuthUser := func(in, out proto.Message, _ *tu.RpcCallOptions) error {
		scope := in.(*user_pb.GetOAuthUserRequest).GetScope()
		if scope != validScope {
			return fmt.Errorf("Invalid scope: %q", scope)
		}
		resp := out.(*user_pb.GetOAuthUserResponse)
		resp.ClientId = proto.String(validClientID)
		resp.Email = proto.String(email)
		resp.UserId = proto.String(userID)
		resp.AuthDomain = proto.String(authDomain)
		resp.IsAdmin = proto.Bool(isAdmin)
		return nil
	}
	unregister := tu.RegisterAPIOverride("user", "GetOAuthUser", getOAuthUser)
	defer unregister()

	req, deleteAppengineCtx := tu.NewTestRequest("GET", "/", nil)
	defer deleteAppengineCtx()
	c := cachingContextFactory(req)

	tt := []struct {
		scopes    []string
		clientIDs []string
		success   bool
	}{
		{empty, empty, false},
		{empty, []string{validClientID}, false},
		{[]string{validScope}, empty, false},
		{[]string{validScope}, []string{validClientID}, true},
		{[]string{"a", validScope, "b"}, []string{"c", validClientID, "d"}, true},
	}
	for _, elem := range tt {
		user, err := CurrentBearerTokenUser(c, elem.scopes, elem.clientIDs)
		switch {
		case elem.success && (err != nil || user == nil):
			// Arguments follow the order of the verbs: user, then error.
			t.Errorf("Did not expect the call to fail with "+
				"scopes=%v ids=%v. User: %+v, Error: %v",
				elem.scopes, elem.clientIDs, user, err)
		case !elem.success && err == nil:
			t.Errorf("Expected an error, got nil: scopes=%v ids=%v",
				elem.scopes, elem.clientIDs)
		}
	}

	scopes := []string{validScope}
	clientIDs := []string{validClientID}
	user, err := CurrentBearerTokenUser(c, scopes, clientIDs)
	if err != nil || user == nil {
		// The field checks below would dereference a nil user.
		t.Fatalf("CurrentBearerTokenUser(%v, %v) = %+v, %v; want a user",
			scopes, clientIDs, user, err)
	}
	const failMsg = "Expected %q, got %q"
	if user.ID != userID {
		t.Errorf(failMsg, userID, user.ID)
	}
	if user.Email != email {
		t.Errorf(failMsg, email, user.Email)
	}
	if user.AuthDomain != authDomain {
		t.Errorf(failMsg, authDomain, user.AuthDomain)
	}
	if user.Admin != isAdmin {
		// %q does not format bools; use %v.
		t.Errorf("Expected %v, got %v", isAdmin, user.Admin)
	}
}
예제 #12
0
// TestGetCachedCertsCacheMiss checks getCachedCerts when memcache does not
// return a cached value.
//
// memcache Get, memcache Set and urlfetch Fetch are all stubbed. Each stub
// reads its behaviour from currTT, the table row currently under test. The
// test checks the decoded result (or error) and whether memcache Set was
// called.
func TestGetCachedCertsCacheMiss(t *testing.T) {
	// tt describes one scenario: the errors each stubbed RPC returns, the
	// fetched HTTP response, and the expected outcome.
	type tt struct {
		mcGetErr, mcSetErr, fetchErr error
		respStatus                   int32
		respContent                  []byte
		cacheControl, age            string

		expected        *certsList
		shouldCallMcSet bool
	}
	// These are shared with the stub closures below. The range loop at the
	// end assigns currTT (and i) directly, so the stubs always see the
	// current row.
	var (
		i           int
		currTT      *tt
		mcSetCalled bool
	)

	// The Get stub fills in no items and returns the row's configured error.
	mcGetStub := func(in, out proto.Message, _ *tu.RpcCallOptions) error {
		return currTT.mcGetErr
	}
	// The Set stub records that it was called and checks the namespace and
	// stored value. NOTE(review): verifyTT is defined elsewhere; it presumably
	// compares each (got, want) pair — confirm.
	mcSetStub := func(in, out proto.Message, _ *tu.RpcCallOptions) error {
		mcSetCalled = true
		req := in.(*mc_pb.MemcacheSetRequest)
		verifyTT(t,
			req.GetNameSpace(), certNamespace,
			string(req.GetItem()[0].Value), string(currTT.respContent))
		return currTT.mcSetErr
	}
	// The Fetch stub serves the row's status, body, cache-control and age
	// headers, and returns the row's fetch error.
	fetchStub := func(in, out proto.Message, _ *tu.RpcCallOptions) error {
		resp := out.(*fetch_pb.URLFetchResponse)
		resp.StatusCode = proto.Int32(currTT.respStatus)
		resp.Content = currTT.respContent
		resp.Header = []*fetch_pb.URLFetchResponse_Header{
			{
				Key:   proto.String("cache-control"),
				Value: proto.String(currTT.cacheControl),
			},
			{
				Key:   proto.String("age"),
				Value: proto.String(currTT.age),
			},
		}
		return currTT.fetchErr
	}
	defer tu.RegisterAPIOverride("memcache", "Get", mcGetStub)()
	defer tu.RegisterAPIOverride("memcache", "Set", mcSetStub)()
	defer tu.RegisterAPIOverride("urlfetch", "Fetch", fetchStub)()
	r, deleteAppengineContext := tu.NewTestRequest("GET", "/", nil)
	defer deleteAppengineContext()

	tts := []*tt{
		// mcGet, mcSet, fetch err, http status, content,
		// cache, age, expected, should mcSet?
		{memcache.ErrCacheMiss, nil, nil, 200, []byte(`{"keyvalues":null}`),
			"max-age=3600", "600", &certsList{}, true},
		{memcache.ErrServerError, nil, nil, 200, []byte(`{"keyvalues":null}`),
			"max-age=3600", "600", &certsList{}, false},
		{memcache.ErrCacheMiss, memcache.ErrServerError, nil, 200,
			[]byte(`{"keyvalues":null}`),
			"max-age=3600", "600", &certsList{}, true},
		{memcache.ErrCacheMiss, nil, errors.New("fetch RPC error"), 0, nil,
			"", "", nil, false},
		{memcache.ErrCacheMiss, nil, nil, 400, []byte(""),
			"", "", nil, false},
		{memcache.ErrCacheMiss, nil, nil, 200, []byte(`{"keyvalues":null}`),
			"", "", &certsList{}, false},
	}

	c := NewContext(r)

	// Plain assignment (not :=) so the closures above observe each row.
	for i, currTT = range tts {
		mcSetCalled = false
		out, err := getCachedCerts(c)
		switch {
		case err != nil && currTT.expected != nil:
			t.Errorf("%d: unexpected error: %v", i, err)
		case err == nil && currTT.expected == nil:
			t.Errorf("%d: expected error, got %#v", i, out)
		default:
			// Reached when the outcome matches the row: success with an
			// expected value, or an error with a nil expected value. Both
			// cases also check whether memcache Set was called.
			assertEquals(t, i, out, currTT.expected)
			if currTT.shouldCallMcSet != mcSetCalled {
				t.Errorf("%d: mc set called? %v, expected: %v",
					i, mcSetCalled, currTT.shouldCallMcSet)
			}
		}
	}
}