예제 #1
0
파일: server.go 프로젝트: derekparker/dex
// ClientCredsToken issues a signed JWT to a client that authenticates with
// its own credentials, with no end user involved.
//
// The client ID is handed to oidc.NewClaims for both identity arguments and is
// also stored under the "name" claim. The token expires after
// s.SessionManager.ValidityWindow.
//
// An oauth2.ErrorInvalidClient error is returned when the credentials are
// rejected; an oauth2.ErrorServerError error is returned when the client
// lookup, the signer retrieval, or the signing step fails.
func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) {
	authenticated, authErr := s.ClientIdentityRepo.Authenticate(creds)
	switch {
	case authErr != nil:
		log.Errorf("Failed fetching client %s from repo: %v", creds.ID, authErr)
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	case !authenticated:
		return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
	}

	keySigner, signerErr := s.KeyManager.Signer()
	if signerErr != nil {
		log.Errorf("Failed to generate ID token: %v", signerErr)
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	}

	issuedAt := time.Now()
	expiresAt := issuedAt.Add(s.SessionManager.ValidityWindow)

	tokenClaims := oidc.NewClaims(s.IssuerURL.String(), creds.ID, creds.ID, issuedAt, expiresAt)
	tokenClaims.Add("name", creds.ID)

	signed, signErr := jose.NewSignedJWT(tokenClaims, keySigner)
	if signErr != nil {
		log.Errorf("Failed to generate ID token: %v", signErr)
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	}

	log.Infof("Client token sent: clientID=%s", creds.ID)
	return signed, nil
}
예제 #2
0
파일: server.go 프로젝트: GamerockSA/dex
// addClaimsFromScope populates claims that are derived from the scopes the
// client asked for. At present that means the cross-client claims "aud" and
// "azp".
//
// When scopes name no cross-client IDs the claims are left untouched.
// Otherwise every named ID other than clientID itself must pass
// s.CrossClientAuthAllowed. "aud" is written as a plain string when there is
// exactly one audience and as a string slice otherwise; "azp" is set to
// clientID.
//
// It returns an oauth2.ErrorServerError error when the permission check itself
// fails, and an oauth2.ErrorInvalidRequest error (with a description) when the
// client is not permitted to act for one of the requested IDs.
func (s *Server) addClaimsFromScope(claims jose.Claims, scopes scope.Scopes, clientID string) error {
	requested := scopes.CrossClientIDs()
	if len(requested) == 0 {
		return nil
	}

	var audience []string
	for _, targetID := range requested {
		// A client may always name itself; only foreign IDs need a check.
		if targetID != clientID {
			permitted, checkErr := s.CrossClientAuthAllowed(clientID, targetID)
			if checkErr != nil {
				log.Errorf("Failed to check cross client auth. reqClientID %v; authClient:ID %v; err: %v", clientID, targetID, checkErr)
				return oauth2.NewError(oauth2.ErrorServerError)
			}
			if !permitted {
				denied := oauth2.NewError(oauth2.ErrorInvalidRequest)
				denied.Description = fmt.Sprintf(
					"%q is not authorized to perform cross-client requests for %q",
					clientID, targetID)
				return denied
			}
		}
		audience = append(audience, targetID)
	}

	switch len(audience) {
	case 1:
		claims.Add("aud", audience[0])
	default:
		claims.Add("aud", audience)
	}
	claims.Add("azp", clientID)
	return nil
}
예제 #3
0
파일: grpc.go 프로젝트: otsimo/accounts
// Token builds a signed ID token for userID and clientID using the supplied
// issued-at (iat) and expiry (exp) times, and creates a matching refresh token.
//
// The user's own claims are merged in via user.AddToClaims; admin users
// additionally receive OtsimoUserTypeClaim set to "adm".
//
// Every failure (signer retrieval, user lookup, signing, refresh token
// creation) is logged and reported to the caller as an
// oauth2.ErrorServerError error with a nil JWT and an empty refresh token.
func (s *grpcServer) Token(userID, clientID string, iat, exp time.Time) (*jose.JWT, string, error) {
	serverErr := func() (*jose.JWT, string, error) {
		return nil, "", oauth2.NewError(oauth2.ErrorServerError)
	}

	keySigner, err := s.server.KeyManager.Signer()
	if err != nil {
		log.Errorf("grpc.go: Failed to generate ID token: %v", err)
		return serverErr()
	}

	account, err := s.server.UserRepo.Get(nil, userID)
	if err != nil {
		log.Errorf("grpc.go: Failed to fetch user %q from repo: %v: ", userID, err)
		return serverErr()
	}

	tokenClaims := oidc.NewClaims(s.server.IssuerURL.String(), userID, clientID, iat, exp)
	account.AddToClaims(tokenClaims)
	if account.Admin {
		tokenClaims.Add(OtsimoUserTypeClaim, "adm")
	}

	signed, err := jose.NewSignedJWT(tokenClaims, keySigner)
	if err != nil {
		log.Errorf("grpc.go: Failed to generate ID token: %v", err)
		return serverErr()
	}

	refresh, err := s.server.RefreshTokenRepo.Create(account.ID, clientID)
	if err != nil {
		log.Errorf("grpc.go: Failed to generate refresh token: %v", err)
		return serverErr()
	}

	return signed, refresh, nil
}
예제 #4
0
파일: error_test.go 프로젝트: Tecsisa/dex
// TestRedirectAuthError checks that redirectAuthError answers with a 302 whose
// only header is a Location carrying the error type and state as query
// parameters, and that the response body stays empty. A plain (non-oauth2)
// error is expected to surface as server_error.
func TestRedirectAuthError(t *testing.T) {
	const wantStatus = http.StatusFound
	target := url.URL{Scheme: "http", Host: "server.example.com"}

	cases := []struct {
		in       error
		state    string
		dest     url.URL
		location string
	}{
		{
			errors.New("foobar"),
			"bazinga",
			target,
			"http://server.example.com?error=server_error&state=bazinga",
		},
		{
			oauth2.NewError(oauth2.ErrorInvalidRequest),
			"foo",
			target,
			"http://server.example.com?error=invalid_request&state=foo",
		},
		{
			oauth2.NewError(oauth2.ErrorUnsupportedResponseType),
			"bar",
			target,
			"http://server.example.com?error=unsupported_response_type&state=bar",
		},
	}

	for idx, tc := range cases {
		rec := httptest.NewRecorder()
		redirectAuthError(rec, tc.in, tc.state, tc.dest)

		if rec.Code != wantStatus {
			t.Errorf("case %d: incorrect HTTP status: want=%d got=%d", idx, wantStatus, rec.Code)
		}

		// Exactly one header (Location) is expected on the response.
		expected := http.Header{"Location": []string{tc.location}}
		if actual := rec.Header(); !reflect.DeepEqual(expected, actual) {
			t.Errorf("case %d: incorrect HTTP headers: want=%#v got=%#v", idx, expected, actual)
		}

		if body := rec.Body.String(); body != "" {
			t.Errorf("case %d: incorrect empty HTTP body, got=%q", idx, body)
		}
	}
}
예제 #5
0
파일: error.go 프로젝트: GamerockSA/dex
// writeAuthError reports err to the client as an HTTP 400 response whose body
// is the oauth2 error with its State field set to state.
//
// Any err that is not a non-nil *oauth2.Error is reported as
// oauth2.ErrorServerError. The caller's error value is never modified: the
// state is attached to a copy, so an error that is shared or reused by the
// caller does not pick up the state of an unrelated request.
func writeAuthError(w http.ResponseWriter, err error, state string) {
	oerr, ok := err.(*oauth2.Error)
	if !ok || oerr == nil {
		// Covers plain errors as well as a typed-nil *oauth2.Error, which
		// would otherwise be dereferenced below.
		oerr = oauth2.NewError(oauth2.ErrorServerError)
	}

	resp := *oerr
	resp.State = state
	writeResponseWithBody(w, http.StatusBadRequest, &resp)
}
예제 #6
0
파일: error_test.go 프로젝트: Tecsisa/dex
// TestWriteAuthError checks that writeAuthError answers with a 400, a JSON
// Content-Type as the only header, and a JSON body holding the error type and
// the state. A plain (non-oauth2) error is expected to surface as
// server_error.
func TestWriteAuthError(t *testing.T) {
	const wantStatus = http.StatusBadRequest
	expectedHeader := http.Header{"Content-Type": []string{"application/json"}}

	cases := []struct {
		in    error
		state string
		body  string
	}{
		{errors.New("foobar"), "bazinga", `{"error":"server_error","state":"bazinga"}`},
		{oauth2.NewError(oauth2.ErrorInvalidRequest), "foo", `{"error":"invalid_request","state":"foo"}`},
		{oauth2.NewError(oauth2.ErrorUnsupportedResponseType), "bar", `{"error":"unsupported_response_type","state":"bar"}`},
	}

	for idx, tc := range cases {
		rec := httptest.NewRecorder()
		writeAuthError(rec, tc.in, tc.state)

		if rec.Code != wantStatus {
			t.Errorf("case %d: incorrect HTTP status: want=%d got=%d", idx, wantStatus, rec.Code)
		}

		if actual := rec.Header(); !reflect.DeepEqual(expectedHeader, actual) {
			t.Errorf("case %d: incorrect HTTP headers: want=%#v got=%#v", idx, expectedHeader, actual)
		}

		if actual := rec.Body.String(); actual != tc.body {
			t.Errorf("case %d: incorrect HTTP body: want=%q got=%q", idx, tc.body, actual)
		}
	}
}
예제 #7
0
파일: server.go 프로젝트: derekparker/dex
// RefreshToken exchanges a refresh token for a newly signed ID token.
//
// The client is authenticated first, then the refresh token is verified
// against the client ID, the owning user is loaded, and a JWT valid for
// session.DefaultSessionValidityWindow is signed with the user's claims.
//
// Returned errors are oauth2 errors:
//   - oauth2.ErrorInvalidClient when authentication fails or the token belongs
//     to a different client (refresh.ErrorInvalidClientID);
//   - oauth2.ErrorInvalidRequest when the token is invalid
//     (refresh.ErrorInvalidToken);
//   - oauth2.ErrorServerError for any repo, signer, or signing failure.
func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error) {
	ok, err := s.ClientIdentityRepo.Authenticate(creds)
	if err != nil {
		log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	}
	if !ok {
		log.Errorf("Failed to Authenticate client %s", creds.ID)
		return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
	}

	userID, err := s.RefreshTokenRepo.Verify(creds.ID, token)
	switch err {
	case nil:
		// Token verified; continue below.
	case refresh.ErrorInvalidToken:
		return nil, oauth2.NewError(oauth2.ErrorInvalidRequest)
	case refresh.ErrorInvalidClientID:
		return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
	default:
		// Unexpected repo failure: record the cause before collapsing it into
		// a generic server error, as every other failure path here does.
		log.Errorf("Failed to verify refresh token: %v", err)
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	}

	user, err := s.UserRepo.Get(nil, userID)
	if err != nil {
		// The error can be user.ErrorNotFound, but we are not deleting
		// user at this moment, so this shouldn't happen.
		log.Errorf("Failed to fetch user %q from repo: %v: ", userID, err)
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	}

	signer, err := s.KeyManager.Signer()
	if err != nil {
		log.Errorf("Failed to refresh ID token: %v", err)
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	}

	now := time.Now()
	expireAt := now.Add(session.DefaultSessionValidityWindow)

	claims := oidc.NewClaims(s.IssuerURL.String(), user.ID, creds.ID, now, expireAt)
	user.AddToClaims(claims)

	jwt, err := jose.NewSignedJWT(claims, signer)
	if err != nil {
		log.Errorf("Failed to generate ID token: %v", err)
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	}

	log.Infof("New token sent: clientID=%s", creds.ID)

	return jwt, nil
}
예제 #8
0
파일: error.go 프로젝트: GamerockSA/dex
// redirectAuthError reports err to the client by redirecting (HTTP 302) to
// redirectURL with "error" and "state" query parameters added to whatever
// query the URL already carries. No body is written.
//
// An err that is not an *oauth2.Error is reported as oauth2.ErrorServerError.
// redirectURL is received by value, so the caller's URL is not modified.
func redirectAuthError(w http.ResponseWriter, err error, state string, redirectURL url.URL) {
	var authErr *oauth2.Error
	switch typed := err.(type) {
	case *oauth2.Error:
		authErr = typed
	default:
		authErr = oauth2.NewError(oauth2.ErrorServerError)
	}

	params := redirectURL.Query()
	params.Set("error", authErr.Type)
	params.Set("state", state)
	redirectURL.RawQuery = params.Encode()

	// The Location header has to be in place before the status line is sent.
	w.Header().Set("Location", redirectURL.String())
	w.WriteHeader(http.StatusFound)
}
예제 #9
0
파일: error.go 프로젝트: GamerockSA/dex
// writeTokenError reports err to the client as an oauth2 error response whose
// body has its State field set to state.
//
// An oauth2.ErrorInvalidClient error is answered with HTTP 401 and a
// "WWW-Authenticate: Basic" header; every other error is answered with
// HTTP 400. Any err that is not a non-nil *oauth2.Error is reported as
// oauth2.ErrorServerError.
//
// The caller's error value is never modified: the state is attached to a
// copy, so an error that is shared or reused by the caller does not pick up
// the state of an unrelated request.
func writeTokenError(w http.ResponseWriter, err error, state string) {
	oerr, ok := err.(*oauth2.Error)
	if !ok || oerr == nil {
		// Covers plain errors as well as a typed-nil *oauth2.Error, which
		// would otherwise be dereferenced below.
		oerr = oauth2.NewError(oauth2.ErrorServerError)
	}

	resp := *oerr
	resp.State = state

	status := http.StatusBadRequest
	if resp.Type == oauth2.ErrorInvalidClient {
		status = http.StatusUnauthorized
		w.Header().Set("WWW-Authenticate", "Basic")
	}

	writeResponseWithBody(w, status, &resp)
}
예제 #10
0
// Identity fetches the current user's profile from the UAA "/userinfo"
// endpoint (resolved relative to c.uaaBaseURL) using cli and maps it to an
// oidc.Identity. The display name falls back to the user name when the
// profile has no name.
//
// For 4xx/5xx responses the body is decoded as a uaaError and returned; if it
// cannot be decoded an oauth2.ErrorAccessDenied error is returned instead.
// Any other non-200 status yields a generic error.
func (c *uaaOAuth2Connector) Identity(cli chttp.Client) (oidc.Identity, error) {
	var none oidc.Identity

	// Copy the base URL so appending the path does not alter the connector.
	endpoint := *c.uaaBaseURL
	endpoint.Path = path.Join(endpoint.Path, "/userinfo")

	req, err := http.NewRequest("GET", endpoint.String(), nil)
	if err != nil {
		return none, err
	}
	resp, err := cli.Do(req)
	if err != nil {
		return none, fmt.Errorf("get: %v", err)
	}
	defer resp.Body.Close()

	if code := resp.StatusCode; code >= 400 && code < 600 {
		// attempt to decode error from UAA
		var remoteErr uaaError
		if decErr := json.NewDecoder(resp.Body).Decode(&remoteErr); decErr != nil {
			return none, oauth2.NewError(oauth2.ErrorAccessDenied)
		}
		return none, remoteErr
	}
	if resp.StatusCode != http.StatusOK {
		return none, fmt.Errorf("unexpected status from providor %s", resp.Status)
	}

	var profile struct {
		UserID   string `json:"user_id"`
		Email    string `json:"email"`
		Name     string `json:"name"`
		UserName string `json:"user_name"`
	}
	if decErr := json.NewDecoder(resp.Body).Decode(&profile); decErr != nil {
		return none, fmt.Errorf("getting user info: %v", decErr)
	}

	displayName := profile.Name
	if displayName == "" {
		displayName = profile.UserName
	}
	return oidc.Identity{
		ID:    profile.UserID,
		Name:  displayName,
		Email: profile.Email,
	}, nil
}
예제 #11
0
// Identity fetches the current user's profile from githubAPIUserURL using cli
// and maps it to an oidc.Identity. The numeric GitHub ID is rendered in base
// 10, and the display name falls back to the login when the profile has no
// name.
//
// For 4xx/5xx responses the body is decoded as a githubError and returned; if
// it cannot be decoded an oauth2.ErrorAccessDenied error is returned instead.
// Any other non-200 status yields a generic error.
func (c *githubOAuth2Connector) Identity(cli chttp.Client) (oidc.Identity, error) {
	var none oidc.Identity

	req, err := http.NewRequest("GET", githubAPIUserURL, nil)
	if err != nil {
		return none, err
	}
	resp, err := cli.Do(req)
	if err != nil {
		return none, fmt.Errorf("get: %v", err)
	}
	defer resp.Body.Close()

	if code := resp.StatusCode; code >= 400 && code < 600 {
		// attempt to decode error from github
		var remoteErr githubError
		if decErr := json.NewDecoder(resp.Body).Decode(&remoteErr); decErr != nil {
			return none, oauth2.NewError(oauth2.ErrorAccessDenied)
		}
		return none, remoteErr
	}
	if resp.StatusCode != http.StatusOK {
		return none, fmt.Errorf("unexpected status from providor %s", resp.Status)
	}

	var profile struct {
		Login string `json:"login"`
		ID    int64  `json:"id"`
		Email string `json:"email"`
		Name  string `json:"name"`
	}
	if decErr := json.NewDecoder(resp.Body).Decode(&profile); decErr != nil {
		return none, fmt.Errorf("getting user info: %v", decErr)
	}

	displayName := profile.Name
	if displayName == "" {
		displayName = profile.Login
	}
	return oidc.Identity{
		ID:    strconv.FormatInt(profile.ID, 10),
		Name:  displayName,
		Email: profile.Email,
	}, nil
}
예제 #12
0
// getAndDecode issues a GET request for url through cli and decodes the JSON
// response body into v.
//
// A 4xx status is reported as an oauth2.ErrorAccessDenied error; any other
// status except 200 is reported as a generic "unexpected status" error.
// Request construction, transport, and decoding failures are returned as
// plain errors.
func getAndDecode(cli chttp.Client, url string, v interface{}) error {
	req, err := http.NewRequest("GET", url, nil)
	if err != nil {
		return err
	}

	resp, err := cli.Do(req)
	if err != nil {
		return fmt.Errorf("get: %v", err)
	}
	defer resp.Body.Close()

	if code := resp.StatusCode; code >= 400 && code < 500 {
		return oauth2.NewError(oauth2.ErrorAccessDenied)
	}
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected status from providor %s", resp.Status)
	}

	if decErr := json.NewDecoder(resp.Body).Decode(v); decErr != nil {
		return fmt.Errorf("decode body: %v", decErr)
	}
	return nil
}
예제 #13
0
// Identity fetches the current user's profile from facebookGraphAPIURL using
// cli and maps its id, name, and email to an oidc.Identity.
//
// For 4xx/5xx responses the body is decoded as a facebookErr and returned; if
// it cannot be decoded an oauth2.ErrorAccessDenied error is returned instead.
// Any other non-200 status yields a generic error.
func (c *facebookOAuth2Connector) Identity(cli chttp.Client) (oidc.Identity, error) {
	var none oidc.Identity

	req, err := http.NewRequest("GET", facebookGraphAPIURL, nil)
	if err != nil {
		return none, err
	}
	resp, err := cli.Do(req)
	if err != nil {
		return none, fmt.Errorf("get: %v", err)
	}
	defer resp.Body.Close()

	if code := resp.StatusCode; code >= 400 && code < 600 {
		var remoteErr facebookErr
		if decErr := json.NewDecoder(resp.Body).Decode(&remoteErr); decErr != nil {
			return none, oauth2.NewError(oauth2.ErrorAccessDenied)
		}
		return none, remoteErr
	}
	if resp.StatusCode != http.StatusOK {
		return none, fmt.Errorf("unexpected status from providor %s", resp.Status)
	}

	var profile struct {
		ID    string `json:"id"`
		Email string `json:"email"`
		Name  string `json:"name"`
	}
	if decErr := json.NewDecoder(resp.Body).Decode(&profile); decErr != nil {
		return none, fmt.Errorf("decode body: %v", decErr)
	}

	return oidc.Identity{
		ID:    profile.ID,
		Name:  profile.Name,
		Email: profile.Email,
	}, nil
}
예제 #14
0
파일: server.go 프로젝트: GamerockSA/dex
// CodeToken exchanges a session key (authorization code) for a signed ID
// token and, when the session's scope contains "offline_access", a refresh
// token. The refresh token is the empty string otherwise.
//
// The client is authenticated, the key is exchanged for a session ID, and the
// session is killed before the token is built, so a key cannot be redeemed
// again after this point even if a later step fails.
//
// Returned errors are oauth2 errors:
//   - oauth2.ErrorInvalidClient when client authentication fails;
//   - oauth2.ErrorInvalidGrant when the key is unknown or the session belongs
//     to a different client;
//   - oauth2.ErrorInvalidRequest when the session cannot be killed;
//   - whatever addClaimsFromScope returns when the scope-derived claims cannot
//     be added;
//   - oauth2.ErrorServerError for repo, signer, signing, or refresh token
//     failures.
func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error) {
	ok, err := s.ClientManager.Authenticate(creds)
	if err != nil {
		log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
		return nil, "", oauth2.NewError(oauth2.ErrorServerError)
	}
	if !ok {
		log.Errorf("Failed to Authenticate client %s", creds.ID)
		return nil, "", oauth2.NewError(oauth2.ErrorInvalidClient)
	}

	sessionID, err := s.SessionManager.ExchangeKey(sessionKey)
	if err != nil {
		return nil, "", oauth2.NewError(oauth2.ErrorInvalidGrant)
	}

	ses, err := s.SessionManager.Kill(sessionID)
	if err != nil {
		return nil, "", oauth2.NewError(oauth2.ErrorInvalidRequest)
	}

	if ses.ClientID != creds.ID {
		return nil, "", oauth2.NewError(oauth2.ErrorInvalidGrant)
	}

	signer, err := s.KeyManager.Signer()
	if err != nil {
		log.Errorf("Failed to generate ID token: %v", err)
		return nil, "", oauth2.NewError(oauth2.ErrorServerError)
	}

	user, err := s.UserRepo.Get(nil, ses.UserID)
	if err != nil {
		log.Errorf("Failed to fetch user %q from repo: %v: ", ses.UserID, err)
		return nil, "", oauth2.NewError(oauth2.ErrorServerError)
	}

	claims := ses.Claims(s.IssuerURL.String())
	user.AddToClaims(claims)

	// Do not issue a token when the scope-derived claims cannot be added
	// (for example a cross-client request that is not authorized);
	// addClaimsFromScope already returns an oauth2 error suitable for the
	// caller.
	if err := s.addClaimsFromScope(claims, ses.Scope, ses.ClientID); err != nil {
		return nil, "", err
	}

	jwt, err := jose.NewSignedJWT(claims, signer)
	if err != nil {
		log.Errorf("Failed to generate ID token: %v", err)
		return nil, "", oauth2.NewError(oauth2.ErrorServerError)
	}

	// Generate refresh token when 'scope' contains 'offline_access'.
	var refreshToken string

	for _, sc := range ses.Scope {
		if sc != "offline_access" {
			continue
		}
		log.Infof("Session %s requests offline access, will generate refresh token", sessionID)

		refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID, ses.Scope)
		if err != nil {
			log.Errorf("Failed to generate refresh token: %v", err)
			return nil, "", oauth2.NewError(oauth2.ErrorServerError)
		}
		// One refresh token is enough, however often the scope is repeated.
		break
	}

	log.Infof("Session %s token sent: clientID=%s", sessionID, creds.ID)
	return jwt, refreshToken, nil
}
예제 #15
0
// TestServerRefreshToken drives Server.RefreshToken through a table of valid
// and invalid token/credential combinations and then checks separately that a
// verified token whose user can no longer be found yields a server error.
//
// Each table case gets a fresh set of repos with a single refresh token
// created for user "testid-1" and the case's owning client; the valid token
// strings used below ("0/<base64 of refresh-1>") rely on that being the first
// token the test repo hands out.
func TestServerRefreshToken(t *testing.T) {
	issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}

	credXXX := oidc.ClientCredentials{ID: "XXX", Secret: "secret"}
	credYYY := oidc.ClientCredentials{ID: "YYY", Secret: "secret"}

	signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}

	// encodeToken renders a refresh token as "<id>/<base64url payload>".
	encodeToken := func(id, payload string) string {
		return fmt.Sprintf("%s/%s", id, base64.URLEncoding.EncodeToString([]byte(payload)))
	}
	validToken := encodeToken("0", "refresh-1")

	cases := []struct {
		token   string
		ownerID string // The client that associates with the token.
		creds   oidc.ClientCredentials
		signer  jose.Signer
		wantErr error
	}{
		// Everything is good.
		{
			token:   validToken,
			ownerID: "XXX",
			creds:   credXXX,
			signer:  signerFixture,
		},
		// Invalid refresh token(malformatted).
		{
			token:   "invalid-token",
			ownerID: "XXX",
			creds:   credXXX,
			signer:  signerFixture,
			wantErr: oauth2.NewError(oauth2.ErrorInvalidRequest),
		},
		// Invalid refresh token(invalid payload content).
		{
			token:   encodeToken("0", "refresh-2"),
			ownerID: "XXX",
			creds:   credXXX,
			signer:  signerFixture,
			wantErr: oauth2.NewError(oauth2.ErrorInvalidRequest),
		},
		// Invalid refresh token(invalid ID content).
		{
			token:   encodeToken("1", "refresh-1"),
			ownerID: "XXX",
			creds:   credXXX,
			signer:  signerFixture,
			wantErr: oauth2.NewError(oauth2.ErrorInvalidRequest),
		},
		// Invalid client(client is not associated with the token).
		{
			token:   validToken,
			ownerID: "XXX",
			creds:   credYYY,
			signer:  signerFixture,
			wantErr: oauth2.NewError(oauth2.ErrorInvalidClient),
		},
		// Invalid client(no client ID).
		{
			token:   validToken,
			ownerID: "XXX",
			creds:   oidc.ClientCredentials{ID: "", Secret: "aaa"},
			signer:  signerFixture,
			wantErr: oauth2.NewError(oauth2.ErrorInvalidClient),
		},
		// Invalid client(no such client).
		{
			token:   validToken,
			ownerID: "XXX",
			creds:   oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
			signer:  signerFixture,
			wantErr: oauth2.NewError(oauth2.ErrorInvalidClient),
		},
		// Invalid client(no secrets).
		{
			token:   validToken,
			ownerID: "XXX",
			creds:   oidc.ClientCredentials{ID: "XXX"},
			signer:  signerFixture,
			wantErr: oauth2.NewError(oauth2.ErrorInvalidClient),
		},
		// Invalid client(invalid secret).
		{
			token:   validToken,
			ownerID: "XXX",
			creds:   oidc.ClientCredentials{ID: "XXX", Secret: "bad-secret"},
			signer:  signerFixture,
			wantErr: oauth2.NewError(oauth2.ErrorInvalidClient),
		},
		// Signing operation fails.
		{
			token:   validToken,
			ownerID: "XXX",
			creds:   credXXX,
			signer:  &StaticSigner{sig: nil, err: errors.New("fail")},
			wantErr: oauth2.NewError(oauth2.ErrorServerError),
		},
	}

	for idx, tc := range cases {
		keys := &StaticKeyManager{signer: tc.signer}

		clients := client.NewClientIdentityRepo([]oidc.ClientIdentity{
			{Credentials: credXXX},
			{Credentials: credYYY},
		})

		users, err := makeNewUserRepo()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		tokens, err := refreshtest.NewTestRefreshTokenRepo()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		server := &Server{
			IssuerURL:          issuerURL,
			KeyManager:         keys,
			ClientIdentityRepo: clients,
			UserRepo:           users,
			RefreshTokenRepo:   tokens,
		}

		if _, err := tokens.Create("testid-1", tc.ownerID); err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		jwt, err := server.RefreshToken(tc.creds, tc.token)
		if !reflect.DeepEqual(err, tc.wantErr) {
			t.Errorf("Case %d: expect: %v, got: %v", idx, tc.wantErr, err)
		}

		if jwt == nil {
			continue
		}
		if string(jwt.Signature) != "beer" {
			t.Errorf("Case %d: expect signature: beer, got signature: %v", idx, jwt.Signature)
		}
		claims, err := jwt.Claims()
		if err != nil {
			t.Errorf("Case %d: unexpected error: %v", idx, err)
		}
		claimsOK := claims["iss"] == issuerURL.String() &&
			claims["sub"] == "testid-1" &&
			claims["aud"] == "XXX"
		if !claimsOK {
			t.Errorf("Case %d: invalid claims: %v", idx, claims)
		}
	}

	// Test that we should return error when user cannot be found after
	// verifying the token.
	keys := &StaticKeyManager{signer: signerFixture}

	clients := client.NewClientIdentityRepo([]oidc.ClientIdentity{
		{Credentials: credXXX},
		{Credentials: credYYY},
	})

	users, err := makeNewUserRepo()
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	// Create a user that will be removed later.
	doomed := user.User{
		ID:    "testid-2",
		Email: "*****@*****.**",
	}
	if err := users.Create(nil, doomed); err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	tokens, err := refreshtest.NewTestRefreshTokenRepo()
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	server := &Server{
		IssuerURL:          issuerURL,
		KeyManager:         keys,
		ClientIdentityRepo: clients,
		UserRepo:           users,
		RefreshTokenRepo:   tokens,
	}

	if _, err := tokens.Create("testid-2", credXXX.ID); err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	// Recreate the user repo to remove the user we created.
	users, err = makeNewUserRepo()
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	server.UserRepo = users

	_, err = server.RefreshToken(credXXX, validToken)
	if wantErr := oauth2.NewError(oauth2.ErrorServerError); !reflect.DeepEqual(err, wantErr) {
		t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err)
	}
}
예제 #16
0
// TestServerTokenFail drives Server.CodeToken through a table covering the
// happy path (with and without "offline_access") and the failure modes:
// unrecognized key, unrecognized client, and a failing signer.
//
// For every case a fresh session is created for ccFixture with the case's
// scope, a remote identity and user "testid-1" are attached, and a session key
// equal to keyFixture is issued. The refresh token is compared first; a
// mismatch aborts the whole test because later cases would not be meaningful.
func TestServerTokenFail(t *testing.T) {
	issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
	keyFixture := "goodkey"
	ccFixture := oidc.ClientCredentials{
		ID:     "XXX",
		Secret: "secrete",
	}
	signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}

	tests := []struct {
		signer       jose.Signer
		argCC        oidc.ClientCredentials
		argKey       string
		err          error
		scope        []string
		refreshToken string
	}{
		// control test case to make sure fixtures check out
		{
			signer:       signerFixture,
			argCC:        ccFixture,
			argKey:       keyFixture,
			scope:        []string{"openid", "offline_access"},
			refreshToken: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
		},

		// no 'offline_access' in 'scope', should get empty refresh token
		{
			signer: signerFixture,
			argCC:  ccFixture,
			argKey: keyFixture,
			scope:  []string{"openid"},
		},

		// unrecognized key
		{
			signer: signerFixture,
			argCC:  ccFixture,
			argKey: "foo",
			err:    oauth2.NewError(oauth2.ErrorInvalidGrant),
			scope:  []string{"openid", "offline_access"},
		},

		// unrecognized client
		{
			signer: signerFixture,
			argCC:  oidc.ClientCredentials{ID: "YYY"},
			argKey: keyFixture,
			err:    oauth2.NewError(oauth2.ErrorInvalidClient),
			scope:  []string{"openid", "offline_access"},
		},

		// signing operation fails
		{
			signer: &StaticSigner{sig: nil, err: errors.New("fail")},
			argCC:  ccFixture,
			argKey: keyFixture,
			err:    oauth2.NewError(oauth2.ErrorServerError),
			scope:  []string{"openid", "offline_access"},
		},
	}

	for i, tt := range tests {
		sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
		// Force the generated session key to the known fixture value.
		sm.GenerateCode = func() (string, error) { return keyFixture, nil }

		sessionID, err := sm.NewSession("connector_id", ccFixture.ID, "bogus", url.URL{}, "", false, tt.scope)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		_, err = sm.AttachRemoteIdentity(sessionID, oidc.Identity{})
		if err != nil {
			t.Errorf("case %d: unexpected error: %v", i, err)
			continue
		}
		km := &StaticKeyManager{
			signer: tt.signer,
		}
		ciRepo := client.NewClientIdentityRepo([]oidc.ClientIdentity{
			oidc.ClientIdentity{Credentials: ccFixture},
		})

		_, err = sm.AttachUser(sessionID, "testid-1")
		if err != nil {
			t.Fatalf("case %d: unexpected error: %v", i, err)
		}

		userRepo, err := makeNewUserRepo()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		srv := &Server{
			IssuerURL:          issuerURL,
			KeyManager:         km,
			SessionManager:     sm,
			ClientIdentityRepo: ciRepo,
			UserRepo:           userRepo,
			RefreshTokenRepo:   refreshTokenRepo,
		}

		_, err = sm.NewSessionKey(sessionID)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		jwt, token, err := srv.CodeToken(tt.argCC, tt.argKey)
		if token != tt.refreshToken {
			// t.Fatalf both reports the mismatch and stops the test, so no
			// extra printing or panicking is needed here.
			t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token)
		}
		if !reflect.DeepEqual(err, tt.err) {
			t.Errorf("case %d: expect %v, got %v", i, tt.err, err)
		}
		if err == nil && jwt == nil {
			t.Errorf("case %d: got nil JWT", i)
		}
		if err != nil && jwt != nil {
			t.Errorf("case %d: got non-nil JWT %v", i, jwt)
		}
	}
}
예제 #17
0
파일: server_test.go 프로젝트: Tecsisa/dex
// TestServerRefreshToken drives Server.RefreshToken through a table of valid
// and invalid combinations of token, client credentials, signer, and scopes,
// including cross-client audiences.
//
// For every case fresh cross-client fixtures are built, a second client
// (clientB) is registered, and one refresh token is created for testUserID1
// with the case's owning client and createScopes. On success the JWT's
// signature and its iss, sub, and aud claims are checked; the returned refresh
// token and expiry are checked in every case.
func TestServerRefreshToken(t *testing.T) {
	clientB := client.Client{
		Credentials: oidc.ClientCredentials{ID: "example2.com", Secret: clientTestSecret},
		Metadata: oidc.ClientMetadata{
			RedirectURIs: []url.URL{
				{Scheme: "https", Host: "example2.com", Path: "one/two/three"},
			},
		},
	}
	clientACreds := oidc.ClientCredentials{
		ID:     "client_a",
		Secret: base64.URLEncoding.EncodeToString([]byte("client_a_secret")),
	}
	signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}

	// NOTE(ericchiang): These tests assume that the database ID of the first
	// refresh token will be "1".
	storedToken := getRefreshTokenEncoded("1", "refresh-1")
	rotatedToken := getRefreshTokenEncoded("1", "refresh-2")

	crossB := scope.ScopeGoogleCrossClient + "client_b"
	crossC := scope.ScopeGoogleCrossClient + "client_c"

	// scopes returns a new slice on every call ("openid", "profile", then any
	// extras) so that no two cases share backing storage.
	scopes := func(extra ...string) []string {
		return append([]string{"openid", "profile"}, extra...)
	}

	cases := []struct {
		token         string
		wantRefresh   string
		ownerID       string // The client that associates with the token.
		creds         oidc.ClientCredentials
		signer        jose.Signer
		createScopes  []string
		refreshScopes []string
		wantAud       []string
		wantErr       error
	}{
		// Everything is good.
		{
			token:         storedToken,
			wantRefresh:   rotatedToken,
			ownerID:       testClientID,
			creds:         testClientCredentials,
			signer:        signerFixture,
			createScopes:  scopes(),
			refreshScopes: scopes(),
		},
		// Asking for a scope not originally granted to you.
		{
			token:         storedToken,
			ownerID:       testClientID,
			creds:         testClientCredentials,
			signer:        signerFixture,
			createScopes:  scopes(),
			refreshScopes: scopes("extra_scope"),
			wantErr:       oauth2.NewError(oauth2.ErrorInvalidRequest),
		},
		// Invalid refresh token(malformatted).
		{
			token:         "invalid-token",
			ownerID:       testClientID,
			creds:         testClientCredentials,
			signer:        signerFixture,
			createScopes:  scopes(),
			refreshScopes: scopes(),
			wantErr:       oauth2.NewError(oauth2.ErrorInvalidRequest),
		},
		// Invalid refresh token(invalid payload content).
		{
			token:         rotatedToken,
			ownerID:       testClientID,
			creds:         testClientCredentials,
			signer:        signerFixture,
			createScopes:  scopes(),
			refreshScopes: scopes(),
			wantErr:       oauth2.NewError(oauth2.ErrorInvalidRequest),
		},
		// Invalid refresh token(invalid ID content).
		{
			token:         getRefreshTokenEncoded("0", "refresh-1"),
			ownerID:       testClientID,
			creds:         testClientCredentials,
			signer:        signerFixture,
			createScopes:  scopes(),
			refreshScopes: scopes(),
			wantErr:       oauth2.NewError(oauth2.ErrorInvalidRequest),
		},
		// Invalid client(client is not associated with the token).
		{
			token:         storedToken,
			ownerID:       testClientID,
			creds:         clientB.Credentials,
			signer:        signerFixture,
			createScopes:  scopes(),
			refreshScopes: scopes(),
			wantErr:       oauth2.NewError(oauth2.ErrorInvalidClient),
		},
		// Invalid client(no client ID).
		{
			token:         storedToken,
			ownerID:       testClientID,
			creds:         oidc.ClientCredentials{ID: "", Secret: "aaa"},
			signer:        signerFixture,
			createScopes:  scopes(),
			refreshScopes: scopes(),
			wantErr:       oauth2.NewError(oauth2.ErrorInvalidClient),
		},
		// Invalid client(no such client).
		{
			token:         storedToken,
			ownerID:       testClientID,
			creds:         oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
			signer:        signerFixture,
			createScopes:  scopes(),
			refreshScopes: scopes(),
			wantErr:       oauth2.NewError(oauth2.ErrorInvalidClient),
		},
		// Invalid client(no secrets).
		{
			token:         storedToken,
			ownerID:       testClientID,
			creds:         oidc.ClientCredentials{ID: testClientID},
			signer:        signerFixture,
			createScopes:  scopes(),
			refreshScopes: scopes(),
			wantErr:       oauth2.NewError(oauth2.ErrorInvalidClient),
		},
		// Invalid client(invalid secret).
		{
			token:         storedToken,
			ownerID:       testClientID,
			creds:         oidc.ClientCredentials{ID: "bad-id", Secret: "bad-secret"},
			signer:        signerFixture,
			createScopes:  scopes(),
			refreshScopes: scopes(),
			wantErr:       oauth2.NewError(oauth2.ErrorInvalidClient),
		},
		// Signing operation fails.
		{
			token:         storedToken,
			ownerID:       testClientID,
			creds:         testClientCredentials,
			signer:        &StaticSigner{sig: nil, err: errors.New("fail")},
			createScopes:  scopes(),
			refreshScopes: scopes(),
			wantErr:       oauth2.NewError(oauth2.ErrorServerError),
		},
		// Valid Cross-Client
		{
			token:         storedToken,
			wantRefresh:   rotatedToken,
			ownerID:       "client_a",
			creds:         clientACreds,
			signer:        signerFixture,
			createScopes:  scopes(crossB),
			refreshScopes: scopes(crossB),
			wantAud:       []string{"client_b"},
		},
		// Valid Cross-Client - but this time we leave out the scopes in the
		// refresh request, which should result in the original stored scopes
		// being used.
		{
			token:         storedToken,
			wantRefresh:   rotatedToken,
			ownerID:       "client_a",
			creds:         clientACreds,
			signer:        signerFixture,
			createScopes:  scopes(crossB),
			refreshScopes: []string{},
			wantAud:       []string{"client_b"},
		},
		// Valid Cross-Client - asking for fewer scopes than originally used
		// when creating the refresh token, which is ok.
		{
			token:         storedToken,
			wantRefresh:   rotatedToken,
			ownerID:       "client_a",
			creds:         clientACreds,
			signer:        signerFixture,
			createScopes:  scopes(crossB, crossC),
			refreshScopes: scopes(crossB),
			wantAud:       []string{"client_b"},
		},
		// Valid Cross-Client - asking for multiple clients in the audience.
		{
			token:         storedToken,
			wantRefresh:   rotatedToken,
			ownerID:       "client_a",
			creds:         clientACreds,
			signer:        signerFixture,
			createScopes:  scopes(crossB, crossC),
			refreshScopes: scopes(crossB, crossC),
			wantAud:       []string{"client_b", "client_c"},
		},
		// Invalid Cross-Client - didn't originally request cross-client when
		// refresh token was created.
		{
			token:         storedToken,
			ownerID:       "client_a",
			creds:         clientACreds,
			signer:        signerFixture,
			createScopes:  scopes(),
			refreshScopes: scopes(crossB),
			wantErr:       oauth2.NewError(oauth2.ErrorInvalidRequest),
		},
	}

	for idx, tc := range cases {
		fixtures, fixErr := makeCrossClientTestFixtures()
		if fixErr != nil {
			t.Fatalf("error making test fixtures: %v", fixErr)
		}
		fixtures.srv.RefreshTokenRepo = refreshtest.NewTestRefreshTokenRepo()
		fixtures.srv.KeyManager = &StaticKeyManager{signer: tc.signer}

		if _, newErr := fixtures.clientRepo.New(nil, clientB); newErr != nil {
			t.Errorf("case %d: error creating other client: %v", idx, newErr)
		}

		if _, createErr := fixtures.srv.RefreshTokenRepo.Create(testUserID1, tc.ownerID, "", tc.createScopes); createErr != nil {
			t.Fatalf("Unexpected error: %v", createErr)
		}

		jwt, gotRefresh, expiry, gotErr := fixtures.srv.RefreshToken(tc.creds, tc.refreshScopes, tc.token)
		if !reflect.DeepEqual(gotErr, tc.wantErr) {
			t.Errorf("Case %d: expect: %v, got: %v", idx, tc.wantErr, gotErr)
		}

		if jwt != nil {
			if string(jwt.Signature) != "beer" {
				t.Errorf("Case %d: expect signature: beer, got signature: %v", idx, jwt.Signature)
			}
			claims, claimsErr := jwt.Claims()
			if claimsErr != nil {
				t.Errorf("Case %d: unexpected error: %v", idx, claimsErr)
			}

			// With no explicit expectation the audience is the default test
			// client; a single audience is a plain string, several a slice.
			var audience interface{}
			switch {
			case tc.wantAud == nil:
				audience = testClientID
			case len(tc.wantAud) == 1:
				audience = tc.wantAud[0]
			default:
				audience = tc.wantAud
			}

			if claims["iss"] != testIssuerURL.String() {
				t.Errorf("Case %d: want=%v, got=%v", idx,
					testIssuerURL.String(), claims["iss"])
			}
			if claims["sub"] != testUserID1 {
				t.Errorf("Case %d: want=%v, got=%v", idx,
					testUserID1, claims["sub"])
			}
			if pretty.Compare(claims["aud"], audience) != "" {
				t.Errorf("Case %d: want=%v, got=%v", idx,
					audience, claims["aud"])
			}
		}

		if pretty.Compare(gotRefresh, tc.wantRefresh) != "" {
			t.Errorf("Case %d: want=%v, got=%v", idx, tc.wantRefresh, gotRefresh)
		}

		if gotErr == nil && expiry.IsZero() {
			t.Errorf("case %d: got zero expiration time %v", idx, expiry)
		}
	}
}
예제 #18
0
파일: server_test.go 프로젝트: Tecsisa/dex
// TestServerTokenFail exercises Server.CodeToken with a table of cases: a
// control case that must succeed, a case without 'offline_access', and cases
// with an unrecognized key, an unrecognized client and a failing signer.
func TestServerTokenFail(t *testing.T) {
	keyFixture := "goodkey"

	signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}

	tests := []struct {
		signer       jose.Signer            // signer installed in the key manager
		argCC        oidc.ClientCredentials // credentials passed to CodeToken
		argKey       string                 // session key (code) passed to CodeToken
		err          error                  // expected error from CodeToken
		scope        []string               // scopes the session is created with
		refreshToken string                 // expected refresh token
	}{
		// control test case to make sure fixtures check out
		{
			// NOTE(ericchiang): This test assumes that the database ID of the first
			// refresh token will be "1".
			signer:       signerFixture,
			argCC:        testClientCredentials,
			argKey:       keyFixture,
			scope:        []string{"openid", "offline_access"},
			refreshToken: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
		},

		// no 'offline_access' in 'scope', should get empty refresh token
		{
			signer: signerFixture,
			argCC:  testClientCredentials,
			argKey: keyFixture,
			scope:  []string{"openid"},
		},

		// unrecognized key
		{
			signer: signerFixture,
			argCC:  testClientCredentials,
			argKey: "foo",
			err:    oauth2.NewError(oauth2.ErrorInvalidGrant),
			scope:  []string{"openid", "offline_access"},
		},

		// unrecognized client
		{
			signer: signerFixture,
			argCC:  oidc.ClientCredentials{ID: "YYY"},
			argKey: keyFixture,
			err:    oauth2.NewError(oauth2.ErrorInvalidClient),
			scope:  []string{"openid", "offline_access"},
		},

		// signing operation fails
		{
			signer: &StaticSigner{sig: nil, err: errors.New("fail")},
			argCC:  testClientCredentials,
			argKey: keyFixture,
			err:    oauth2.NewError(oauth2.ErrorServerError),
			scope:  []string{"openid", "offline_access"},
		},
	}

	for i, tt := range tests {
		// Every case gets fresh fixtures so that state does not leak between
		// cases (the control case relies on the first refresh token ID).
		f, err := makeTestFixtures()
		if err != nil {
			t.Fatalf("case %d: error making test fixtures: %v", i, err)
		}
		sm := f.sessionManager
		// Force the generated session key to the known fixture value.
		sm.GenerateCode = func() (string, error) { return keyFixture, nil }
		f.srv.RefreshTokenRepo = refreshtest.NewTestRefreshTokenRepo()
		f.srv.KeyManager = &StaticKeyManager{
			signer: tt.signer,
		}

		sessionID, err := sm.NewSession(testConnectorID1, testClientID, "bogus", url.URL{}, "", false, tt.scope)
		if err != nil {
			t.Fatalf("case %d: unexpected error: %v", i, err)
		}

		_, err = sm.AttachRemoteIdentity(sessionID, oidc.Identity{})
		if err != nil {
			t.Errorf("case %d: unexpected error: %v", i, err)
			continue
		}
		_, err = sm.AttachUser(sessionID, testUserID1)
		if err != nil {
			t.Fatalf("case %d: unexpected error: %v", i, err)
		}

		_, err = sm.NewSessionKey(sessionID)
		if err != nil {
			t.Fatalf("case %d: unexpected error: %v", i, err)
		}

		jwt, token, expiresAt, err := f.srv.CodeToken(tt.argCC, tt.argKey)
		if token != tt.refreshToken {
			// Fatalf stops the test here, so nothing else is needed.
			t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token)
		}
		if !reflect.DeepEqual(err, tt.err) {
			t.Errorf("case %d: expect %v, got %v", i, tt.err, err)
		}
		if err == nil && jwt == nil {
			t.Errorf("case %d: got nil JWT", i)
		}
		if err != nil && jwt != nil {
			t.Errorf("case %d: got non-nil JWT %v", i, jwt)
		}
		if err == nil && expiresAt.IsZero() {
			t.Errorf("case %d: got zero expiration time %v", i, expiresAt)
		}
	}
}
예제 #19
0
파일: http.go 프로젝트: Tecsisa/dex
// validateScopes checks the scopes of an auth request made by clientID.
//
// It returns an oauth2 invalid_request error when a scope appears more than
// once, when a cross-client scope names another client that srv does not
// allow clientID to act for, when a scope is not recognized, or when
// "openid" is missing. An error from srv.CrossClientAuthAllowed is returned
// unchanged. A nil return means every scope was accepted.
func validateScopes(srv OIDCServer, clientID string, scopes []string) error {
	foundOpenIDScope := false

	// Remember every scope seen so far so that duplicates are rejected even
	// when they are not adjacent. The caller's slice is left untouched.
	seen := make(map[string]struct{}, len(scopes))
	for _, curScope := range scopes {
		if _, dup := seen[curScope]; dup {
			err := oauth2.NewError(oauth2.ErrorInvalidRequest)
			err.Description = fmt.Sprintf(
				"Duplicate scopes are not allowed: %q",
				curScope)
			return err
		}
		seen[curScope] = struct{}{}

		switch {
		case strings.HasPrefix(curScope, scope.ScopeGoogleCrossClient):
			// The remainder of the scope after the prefix is the ID of the
			// client the requester wants to act for.
			otherClient := curScope[len(scope.ScopeGoogleCrossClient):]
			var allowed bool
			var err error
			if otherClient == clientID {
				// A client may always name itself.
				allowed = true
			} else {
				allowed, err = srv.CrossClientAuthAllowed(clientID, otherClient)
				if err != nil {
					return err
				}
			}

			if !allowed {
				err := oauth2.NewError(oauth2.ErrorInvalidRequest)
				err.Description = fmt.Sprintf(
					"%q is not authorized to perform cross-client requests for %q",
					clientID, otherClient)
				return err
			}
		case curScope == "openid":
			foundOpenIDScope = true
		case curScope == "profile":
		case curScope == "email":
		case curScope == scope.ScopeGroups:
		case curScope == "offline_access":
			// According to the spec, for offline_access scope, the client must
			// use a response_type value that would result in an Authorization
			// Code.  Currently oauth2.ResponseTypeCode is the only supported
			// response type, and it's been checked above, so we don't need to
			// check it again here.
			//
			// TODO(yifan): Verify that 'consent' should be in 'prompt'.
		default:
			// Reject all other scopes.
			err := oauth2.NewError(oauth2.ErrorInvalidRequest)
			err.Description = fmt.Sprintf("%q is not a recognized scope", curScope)
			return err
		}
	}

	if !foundOpenIDScope {
		log.Errorf("Invalid auth request: missing 'openid' in 'scope'")
		err := oauth2.NewError(oauth2.ErrorInvalidRequest)
		err.Description = "Invalid auth request: missing 'openid' in 'scope'"
		return err
	}
	return nil
}
예제 #20
0
파일: http.go 프로젝트: set321go/dex
// handleTokenFunc returns the HTTP handler for the token endpoint.
//
// Only POST is accepted. The client authenticates with HTTP basic auth, and
// the "grant_type" form value selects the exchange performed against srv: an
// authorization code, client credentials, or a refresh token. On success the
// result is written as a JSON-encoded oAuth2Token with status 200; failures
// after the method check are reported through writeTokenError together with
// the "state" form value once it has been read.
func handleTokenFunc(srv OIDCServer) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if r.Method != "POST" {
			w.Header().Set("Allow", "POST")
			phttp.WriteError(w, http.StatusMethodNotAllowed, "POST only acceptable method")
			return
		}

		err := r.ParseForm()
		if err != nil {
			log.Errorf("error parsing request: %v", err)
			writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), "")
			return
		}

		state := r.PostForm.Get("state")

		user, password, ok := r.BasicAuth()
		if !ok {
			log.Errorf("error parsing basic auth")
			writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidClient), state)
			return
		}

		creds := oidc.ClientCredentials{ID: user, Secret: password}

		var jwt *jose.JWT
		var refreshToken string
		grantType := r.PostForm.Get("grant_type")

		switch grantType {
		case oauth2.GrantTypeAuthCode:
			code := r.PostForm.Get("code")
			if code == "" {
				log.Errorf("missing code param")
				writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
				return
			}
			jwt, refreshToken, err = srv.CodeToken(creds, code)
			if err != nil {
				log.Errorf("couldn't exchange code for token: %v", err)
				writeTokenError(w, err, state)
				return
			}
		case oauth2.GrantTypeClientCreds:
			jwt, err = srv.ClientCredsToken(creds)
			if err != nil {
				log.Errorf("couldn't exchange client credentials for token: %v", err)
				writeTokenError(w, err, state)
				return
			}
		case oauth2.GrantTypeRefreshToken:
			token := r.PostForm.Get("refresh_token")
			if token == "" {
				log.Errorf("missing refresh_token param")
				writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
				return
			}
			jwt, err = srv.RefreshToken(creds, token)
			if err != nil {
				log.Errorf("couldn't exchange refresh token for token: %v", err)
				writeTokenError(w, err, state)
				return
			}
		default:
			log.Errorf("unsupported grant: %v", grantType)
			writeTokenError(w, oauth2.NewError(oauth2.ErrorUnsupportedGrantType), state)
			return
		}

		// The same encoded JWT is returned as both access token and ID token.
		t := oAuth2Token{
			AccessToken:  jwt.Encode(),
			IDToken:      jwt.Encode(),
			TokenType:    "bearer",
			RefreshToken: refreshToken,
		}

		b, err := json.Marshal(t)
		if err != nil {
			log.Errorf("Failed marshaling %#v to JSON: %v", t, err)
			writeTokenError(w, oauth2.NewError(oauth2.ErrorServerError), state)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		// The status line is already sent, so a write failure can only be logged.
		if _, err := w.Write(b); err != nil {
			log.Errorf("Failed writing token response: %v", err)
		}
	}
}
예제 #21
0
파일: http.go 프로젝트: set321go/dex
// handleAuthFunc returns the HTTP handler for the authorization endpoint.
//
// Only GET is accepted. The login page is rendered when the request carries
// an "error" parameter (after killing the session named by "state") or when
// "connector_id" does not match one of idpcs. Otherwise the auth code request
// is checked against the client's registered redirect URLs, the supported
// response type and the presence of the "openid" scope; a new session is then
// created and the user agent is redirected either to the registration page
// (registration requested on a local connector) or to the connector's login
// URL.
func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.Template, registrationEnabled bool) http.HandlerFunc {
	idx := makeConnectorMap(idpcs)
	return func(w http.ResponseWriter, r *http.Request) {
		if r.Method != "GET" {
			w.Header().Set("Allow", "GET")
			phttp.WriteError(w, http.StatusMethodNotAllowed, "GET only acceptable method")
			return
		}

		q := r.URL.Query()
		// Registration is honoured only when it is enabled on the server.
		register := q.Get("register") == "1" && registrationEnabled
		e := q.Get("error")
		if e != "" {
			sessionKey := q.Get("state")
			if err := srv.KillSession(sessionKey); err != nil {
				log.Errorf("Failed killing sessionKey %q: %v", sessionKey, err)
			}
			renderLoginPage(w, r, srv, idpcs, register, tpl)
			return
		}

		connectorID := q.Get("connector_id")
		idpc, ok := idx[connectorID]
		if !ok {
			renderLoginPage(w, r, srv, idpcs, register, tpl)
			return
		}

		acr, err := oauth2.ParseAuthCodeRequest(q)
		if err != nil {
			log.Errorf("Invalid auth request: %v", err)
			writeAuthError(w, err, acr.State)
			return
		}

		cm, err := srv.ClientMetadata(acr.ClientID)
		if err != nil {
			log.Errorf("Failed fetching client %q from repo: %v", acr.ClientID, err)
			writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State)
			return
		}
		if cm == nil {
			log.Errorf("Client %q not found", acr.ClientID)
			writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
			return
		}

		if len(cm.RedirectURLs) == 0 {
			log.Errorf("Client %q has no redirect URLs", acr.ClientID)
			writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State)
			return
		}

		redirectURL, err := client.ValidRedirectURL(acr.RedirectURL, cm.RedirectURLs)
		if err != nil {
			switch err {
			case client.ErrorCantChooseRedirectURL:
				log.Errorf("Request must provide redirect URL as client %q has registered many", acr.ClientID)
				writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
			case client.ErrorInvalidRedirectURL:
				log.Errorf("Request provided unregistered redirect URL: %s", acr.RedirectURL)
				writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
			case client.ErrorNoValidRedirectURLs:
				log.Errorf("There are no registered URLs for the requested client: %s", acr.RedirectURL)
				writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
			default:
				// Never continue with an unvalidated redirect URL.
				log.Errorf("Failed validating redirect URL for client %q: %v", acr.ClientID, err)
				writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State)
			}
			return
		}

		if acr.ResponseType != oauth2.ResponseTypeCode {
			log.Errorf("unexpected ResponseType: %v: ", acr.ResponseType)
			redirectAuthError(w, oauth2.NewError(oauth2.ErrorUnsupportedResponseType), acr.State, redirectURL)
			return
		}

		// Check scopes. The only requirement enforced here is that "openid" is
		// present; all other scopes are passed through to the session as-is.
		//
		// According to the spec, for offline_access scope, the client must
		// use a response_type value that would result in an Authorization Code.
		// Currently oauth2.ResponseTypeCode is the only supported response type,
		// and it's been checked above, so we don't need to check it again here.
		//
		// TODO(yifan): Verify that 'consent' should be in 'prompt'.
		foundOpenIDScope := false
		for _, s := range acr.Scope {
			if s == "openid" {
				foundOpenIDScope = true
				break
			}
		}

		if !foundOpenIDScope {
			log.Errorf("Invalid auth request: missing 'openid' in 'scope'")
			writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
			return
		}

		nonce := q.Get("nonce")

		key, err := srv.NewSession(connectorID, acr.ClientID, acr.State, redirectURL, nonce, register, acr.Scope)
		if err != nil {
			log.Errorf("Error creating new session: %v: ", err)
			redirectAuthError(w, err, acr.State, redirectURL)
			return
		}

		if register {
			// Local connectors handle registration on the server's own page.
			_, ok := idpc.(*connector.LocalConnector)
			if ok {
				q := url.Values{}
				q.Set("code", key)
				ru := httpPathRegister + "?" + q.Encode()
				w.Header().Set("Location", ru)
				w.WriteHeader(http.StatusFound)
				return
			}
		}

		var p string
		// NOTE(review): when register is true this value is overwritten by the
		// next statement, so "select_account consent" is never used. Confirm
		// which prompt is intended before changing either branch.
		if register {
			p = "select_account consent"
		}
		if shouldReprompt(r) || register {
			p = "select_account"
		}
		lu, err := idpc.LoginURL(key, p)
		if err != nil {
			log.Errorf("Connector.LoginURL failed: %v", err)
			redirectAuthError(w, err, acr.State, redirectURL)
			return
		}

		http.SetCookie(w, createLastSeenCookie())
		w.Header().Set("Location", lu)
		w.WriteHeader(http.StatusFound)
	}
}
예제 #22
0
파일: server_test.go 프로젝트: fnordahl/dex
// TestServerRefreshToken checks Server.RefreshToken against a table of valid
// and invalid token/credential combinations, and then checks that a server
// error is returned when the token verifies but its user is no longer in the
// user repo.
func TestServerRefreshToken(t *testing.T) {
	issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
	clientA := client.Client{
		Credentials: oidc.ClientCredentials{
			ID:     testClientID,
			Secret: clientTestSecret,
		},
		Metadata: oidc.ClientMetadata{
			RedirectURIs: []url.URL{
				{Scheme: "https", Host: "client.example.com", Path: "one/two/three"},
			},
		},
	}
	clientB := client.Client{
		Credentials: oidc.ClientCredentials{
			ID:     "example2.com",
			Secret: clientTestSecret,
		},
		Metadata: oidc.ClientMetadata{
			RedirectURIs: []url.URL{
				{Scheme: "https", Host: "example2.com", Path: "one/two/three"},
			},
		},
	}

	signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}

	// NOTE(ericchiang): These tests assume that the database ID of the first
	// refresh token will be "1".
	goodToken := fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1")))

	tests := []struct {
		token    string
		clientID string // The client that associates with the token.
		creds    oidc.ClientCredentials
		signer   jose.Signer
		err      error
	}{
		// Everything is good.
		{
			token:    goodToken,
			clientID: clientA.Credentials.ID,
			creds:    clientA.Credentials,
			signer:   signerFixture,
			err:      nil,
		},
		// Invalid refresh token(malformatted).
		{
			token:    "invalid-token",
			clientID: clientA.Credentials.ID,
			creds:    clientA.Credentials,
			signer:   signerFixture,
			err:      oauth2.NewError(oauth2.ErrorInvalidRequest),
		},
		// Invalid refresh token(invalid payload content).
		{
			token:    fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))),
			clientID: clientA.Credentials.ID,
			creds:    clientA.Credentials,
			signer:   signerFixture,
			err:      oauth2.NewError(oauth2.ErrorInvalidRequest),
		},
		// Invalid refresh token(invalid ID content).
		{
			token:    fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
			clientID: clientA.Credentials.ID,
			creds:    clientA.Credentials,
			signer:   signerFixture,
			err:      oauth2.NewError(oauth2.ErrorInvalidRequest),
		},
		// Invalid client(client is not associated with the token).
		{
			token:    goodToken,
			clientID: clientA.Credentials.ID,
			creds:    clientB.Credentials,
			signer:   signerFixture,
			err:      oauth2.NewError(oauth2.ErrorInvalidClient),
		},
		// Invalid client(no client ID).
		{
			token:    goodToken,
			clientID: clientA.Credentials.ID,
			creds:    oidc.ClientCredentials{ID: "", Secret: "aaa"},
			signer:   signerFixture,
			err:      oauth2.NewError(oauth2.ErrorInvalidClient),
		},
		// Invalid client(no such client).
		{
			token:    goodToken,
			clientID: clientA.Credentials.ID,
			creds:    oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
			signer:   signerFixture,
			err:      oauth2.NewError(oauth2.ErrorInvalidClient),
		},
		// Invalid client(no secrets).
		{
			token:    goodToken,
			clientID: clientA.Credentials.ID,
			creds:    oidc.ClientCredentials{ID: testClientID},
			signer:   signerFixture,
			err:      oauth2.NewError(oauth2.ErrorInvalidClient),
		},
		// Invalid client(invalid secret).
		{
			token:    goodToken,
			clientID: clientA.Credentials.ID,
			creds:    oidc.ClientCredentials{ID: "bad-id", Secret: "bad-secret"},
			signer:   signerFixture,
			err:      oauth2.NewError(oauth2.ErrorInvalidClient),
		},
		// Signing operation fails.
		{
			token:    goodToken,
			clientID: clientA.Credentials.ID,
			creds:    clientA.Credentials,
			signer:   &StaticSigner{sig: nil, err: errors.New("fail")},
			err:      oauth2.NewError(oauth2.ErrorServerError),
		},
	}

	// Generators handed to the client manager; both are stateless, so one
	// instance serves every server built below.
	clientIDGenerator := func(hostport string) (string, error) {
		return hostport, nil
	}
	secGen := func() ([]byte, error) {
		return []byte("secret"), nil
	}

	for i, tt := range tests {
		// Build a fresh server, backed by its own in-memory DB, per case.
		dbm := db.NewMemDB()
		clientRepo := db.NewClientRepo(dbm)
		clientManager, err := clientmanager.NewClientManagerFromClients(
			clientRepo,
			db.TransactionFactory(dbm),
			[]client.Client{clientA, clientB},
			clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen},
		)
		if err != nil {
			t.Fatalf("Failed to create client identity manager: %v", err)
		}
		userRepo, err := makeNewUserRepo()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()

		srv := &Server{
			IssuerURL:        issuerURL,
			KeyManager:       &StaticKeyManager{signer: tt.signer},
			ClientRepo:       clientRepo,
			ClientManager:    clientManager,
			UserRepo:         userRepo,
			RefreshTokenRepo: refreshTokenRepo,
		}

		// Seed the repo with the token that the case then tries to redeem.
		_, err = refreshTokenRepo.Create("testid-1", tt.clientID)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		jwt, err := srv.RefreshToken(tt.creds, tt.token)
		if !reflect.DeepEqual(err, tt.err) {
			t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err)
		}

		if jwt == nil {
			continue
		}

		if string(jwt.Signature) != "beer" {
			t.Errorf("Case %d: expect signature: beer, got signature: %v", i, jwt.Signature)
		}
		claims, err := jwt.Claims()
		if err != nil {
			t.Errorf("Case %d: unexpected error: %v", i, err)
		}
		claimsOK := claims["iss"] == issuerURL.String() &&
			claims["sub"] == "testid-1" &&
			claims["aud"] == testClientID
		if !claimsOK {
			t.Errorf("Case %d: invalid claims: %v", i, claims)
		}
	}

	// Test that we should return error when user cannot be found after
	// verifying the token.
	dbm := db.NewMemDB()
	clientRepo := db.NewClientRepo(dbm)
	clientManager, err := clientmanager.NewClientManagerFromClients(
		clientRepo,
		db.TransactionFactory(dbm),
		[]client.Client{clientA, clientB},
		clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen},
	)
	if err != nil {
		t.Fatalf("Failed to create client identity manager: %v", err)
	}
	userRepo, err := makeNewUserRepo()
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	// Create a user that will be removed later.
	err = userRepo.Create(nil, user.User{
		ID:    "testid-2",
		Email: "*****@*****.**",
	})
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()

	srv := &Server{
		IssuerURL:        issuerURL,
		KeyManager:       &StaticKeyManager{signer: signerFixture},
		ClientRepo:       clientRepo,
		ClientManager:    clientManager,
		UserRepo:         userRepo,
		RefreshTokenRepo: refreshTokenRepo,
	}

	_, err = refreshTokenRepo.Create("testid-2", clientA.Credentials.ID)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	// Recreate the user repo to remove the user we created.
	userRepo, err = makeNewUserRepo()
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	srv.UserRepo = userRepo

	wantErr := oauth2.NewError(oauth2.ErrorServerError)
	_, err = srv.RefreshToken(clientA.Credentials, goodToken)
	if !reflect.DeepEqual(err, wantErr) {
		t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err)
	}
}
예제 #23
0
파일: server_test.go 프로젝트: fnordahl/dex
// TestServerTokenFail exercises Server.CodeToken with a table of cases: a
// control case that must succeed, a case without 'offline_access', and cases
// with an unrecognized key, an unrecognized client and a failing signer.
func TestServerTokenFail(t *testing.T) {
	issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
	keyFixture := "goodkey"
	ccFixture := oidc.ClientCredentials{
		ID:     testClientID,
		Secret: clientTestSecret,
	}
	signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}

	tests := []struct {
		signer       jose.Signer            // signer installed in the key manager
		argCC        oidc.ClientCredentials // credentials passed to CodeToken
		argKey       string                 // session key (code) passed to CodeToken
		err          error                  // expected error from CodeToken
		scope        []string               // scopes the session is created with
		refreshToken string                 // expected refresh token
	}{
		// control test case to make sure fixtures check out
		{
			// NOTE(ericchiang): This test assumes that the database ID of the first
			// refresh token will be "1".
			signer:       signerFixture,
			argCC:        ccFixture,
			argKey:       keyFixture,
			scope:        []string{"openid", "offline_access"},
			refreshToken: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
		},

		// no 'offline_access' in 'scope', should get empty refresh token
		{
			signer: signerFixture,
			argCC:  ccFixture,
			argKey: keyFixture,
			scope:  []string{"openid"},
		},

		// unrecognized key
		{
			signer: signerFixture,
			argCC:  ccFixture,
			argKey: "foo",
			err:    oauth2.NewError(oauth2.ErrorInvalidGrant),
			scope:  []string{"openid", "offline_access"},
		},

		// unrecognized client
		{
			signer: signerFixture,
			argCC:  oidc.ClientCredentials{ID: "YYY"},
			argKey: keyFixture,
			err:    oauth2.NewError(oauth2.ErrorInvalidClient),
			scope:  []string{"openid", "offline_access"},
		},

		// signing operation fails
		{
			signer: &StaticSigner{sig: nil, err: errors.New("fail")},
			argCC:  ccFixture,
			argKey: keyFixture,
			err:    oauth2.NewError(oauth2.ErrorServerError),
			scope:  []string{"openid", "offline_access"},
		},
	}

	for i, tt := range tests {
		sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
		// Force the generated session key to the known fixture value.
		sm.GenerateCode = func() (string, error) { return keyFixture, nil }

		sessionID, err := sm.NewSession("connector_id", ccFixture.ID, "bogus", url.URL{}, "", false, tt.scope)
		if err != nil {
			t.Fatalf("case %d: unexpected error: %v", i, err)
		}

		_, err = sm.AttachRemoteIdentity(sessionID, oidc.Identity{})
		if err != nil {
			t.Errorf("case %d: unexpected error: %v", i, err)
			continue
		}
		km := &StaticKeyManager{
			signer: tt.signer,
		}

		clients := []client.Client{
			{
				Credentials: ccFixture,
				Metadata: oidc.ClientMetadata{
					RedirectURIs: []url.URL{
						validRedirURL,
					},
				},
			},
		}
		dbm := db.NewMemDB()
		clientIDGenerator := func(hostport string) (string, error) {
			return hostport, nil
		}
		secGen := func() ([]byte, error) {
			return []byte("secret"), nil
		}
		clientRepo := db.NewClientRepo(dbm)
		clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
		if err != nil {
			t.Fatalf("case %d: failed to create client identity manager: %v", i, err)
		}
		_, err = sm.AttachUser(sessionID, "testid-1")
		if err != nil {
			t.Fatalf("case %d: unexpected error: %v", i, err)
		}

		userRepo, err := makeNewUserRepo()
		if err != nil {
			t.Fatalf("case %d: unexpected error: %v", i, err)
		}

		refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()

		srv := &Server{
			IssuerURL:        issuerURL,
			KeyManager:       km,
			SessionManager:   sm,
			ClientRepo:       clientRepo,
			ClientManager:    clientManager,
			UserRepo:         userRepo,
			RefreshTokenRepo: refreshTokenRepo,
		}

		_, err = sm.NewSessionKey(sessionID)
		if err != nil {
			t.Fatalf("case %d: unexpected error: %v", i, err)
		}

		jwt, token, err := srv.CodeToken(tt.argCC, tt.argKey)
		if token != tt.refreshToken {
			// Fatalf stops the test here, so nothing else is needed.
			t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token)
		}
		if !reflect.DeepEqual(err, tt.err) {
			t.Errorf("case %d: expect %v, got %v", i, tt.err, err)
		}
		if err == nil && jwt == nil {
			t.Errorf("case %d: got nil JWT", i)
		}
		if err != nil && jwt != nil {
			t.Errorf("case %d: got non-nil JWT %v", i, jwt)
		}
	}
}
예제 #24
0
파일: http.go 프로젝트: Tecsisa/dex
// handleAuthFunc returns the HTTP handler for the authorization endpoint.
//
// Only GET is accepted. The login page is rendered when the request carries
// an "error" parameter (after killing the session named by "state") or when
// "connector_id" does not match one of idpcs. Otherwise the auth code request
// is checked against the client's redirect URLs, the supported response type
// and validateScopes; a new session is then created and the user agent is
// redirected either to the registration page under baseURL (registration
// requested on a local connector) or to the connector's login URL.
func handleAuthFunc(srv OIDCServer, baseURL url.URL, idpcs []connector.Connector, tpl *template.Template, registrationEnabled bool) http.HandlerFunc {
	idx := makeConnectorMap(idpcs)
	return func(w http.ResponseWriter, r *http.Request) {
		if r.Method != "GET" {
			w.Header().Set("Allow", "GET")
			phttp.WriteError(w, http.StatusMethodNotAllowed, "GET only acceptable method")
			return
		}

		q := r.URL.Query()
		// Registration is honoured only when it is enabled on the server.
		register := q.Get("register") == "1" && registrationEnabled
		e := q.Get("error")
		if e != "" {
			sessionKey := q.Get("state")
			if err := srv.KillSession(sessionKey); err != nil {
				log.Errorf("Failed killing sessionKey %q: %v", sessionKey, err)
			}
			renderLoginPage(w, r, srv, idpcs, register, tpl)
			return
		}

		connectorID := q.Get("connector_id")
		idpc, ok := idx[connectorID]
		if !ok {
			renderLoginPage(w, r, srv, idpcs, register, tpl)
			return
		}

		acr, err := oauth2.ParseAuthCodeRequest(q)
		if err != nil {
			log.Errorf("Invalid auth request: %v", err)
			writeAuthError(w, err, acr.State)
			return
		}

		cli, err := srv.Client(acr.ClientID)
		// The not-found check must come before the generic error check;
		// otherwise an unknown client is reported as a server error.
		if err == client.ErrorNotFound {
			log.Errorf("Client %q not found", acr.ClientID)
			writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
			return
		}
		if err != nil {
			log.Errorf("Failed fetching client %q from repo: %v", acr.ClientID, err)
			writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State)
			return
		}

		redirectURL, err := cli.ValidRedirectURL(acr.RedirectURL)
		if err != nil {
			switch err {
			case client.ErrorCantChooseRedirectURL:
				log.Errorf("Request must provide redirect URL as client %q has registered many", acr.ClientID)
				writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
			case client.ErrorInvalidRedirectURL:
				log.Errorf("Request provided unregistered redirect URL: %s", acr.RedirectURL)
				writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
			case client.ErrorNoValidRedirectURLs:
				log.Errorf("There are no registered URLs for the requested client: %s", acr.RedirectURL)
				writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
			default:
				// Never continue with an unvalidated redirect URL.
				log.Errorf("Failed validating redirect URL for client %q: %v", acr.ClientID, err)
				writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State)
			}
			return
		}

		if acr.ResponseType != oauth2.ResponseTypeCode {
			log.Errorf("unexpected ResponseType: %v: ", acr.ResponseType)
			redirectAuthError(w, oauth2.NewError(oauth2.ErrorUnsupportedResponseType), acr.State, redirectURL)
			return
		}

		// Check scopes.
		if scopeErr := validateScopes(srv, acr.ClientID, acr.Scope); scopeErr != nil {
			log.Error(scopeErr)
			writeAuthError(w, scopeErr, acr.State)
			return
		}

		nonce := q.Get("nonce")

		key, err := srv.NewSession(connectorID, acr.ClientID, acr.State, redirectURL, nonce, register, acr.Scope)
		if err != nil {
			log.Errorf("Error creating new session: %v: ", err)
			redirectAuthError(w, err, acr.State, redirectURL)
			return
		}

		if register {
			// Local connectors handle registration on the server's own page.
			_, ok := idpc.(*connector.LocalConnector)
			if ok {
				q := url.Values{}
				q.Set("code", key)
				ru := path.Join(baseURL.Path, httpPathRegister) + "?" + q.Encode()
				w.Header().Set("Location", ru)
				w.WriteHeader(http.StatusFound)
				return
			}
		}

		var p string
		// NOTE(review): when register is true this value is overwritten by the
		// next statement, so "select_account consent" is never used. Confirm
		// which prompt is intended before changing either branch.
		if register {
			p = "select_account consent"
		}
		if shouldReprompt(r) || register {
			p = "select_account"
		}
		lu, err := idpc.LoginURL(key, p)
		if err != nil {
			log.Errorf("Connector.LoginURL failed: %v", err)
			redirectAuthError(w, err, acr.State, redirectURL)
			return
		}

		http.SetCookie(w, createLastSeenCookie())
		w.Header().Set("Location", lu)
		w.WriteHeader(http.StatusFound)
	}
}
예제 #25
0
파일: error_test.go 프로젝트: Tecsisa/dex
// TestWriteTokenError feeds writeTokenError a series of errors (with and
// without a state value) and compares the recorded status code, headers and
// body against the expected response.
func TestWriteTokenError(t *testing.T) {
	// jsonOnly builds the header set expected for a plain JSON error reply.
	jsonOnly := func() http.Header {
		return http.Header{"Content-Type": {"application/json"}}
	}

	cases := []struct {
		err        error
		state      string
		wantCode   int
		wantHeader http.Header
		wantBody   string
	}{
		{
			err:        oauth2.NewError(oauth2.ErrorInvalidRequest),
			state:      "bazinga",
			wantCode:   http.StatusBadRequest,
			wantHeader: jsonOnly(),
			wantBody:   `{"error":"invalid_request","state":"bazinga"}`,
		},
		{
			err:        oauth2.NewError(oauth2.ErrorInvalidRequest),
			wantCode:   http.StatusBadRequest,
			wantHeader: jsonOnly(),
			wantBody:   `{"error":"invalid_request"}`,
		},
		{
			err:        oauth2.NewError(oauth2.ErrorInvalidGrant),
			wantCode:   http.StatusBadRequest,
			wantHeader: jsonOnly(),
			wantBody:   `{"error":"invalid_grant"}`,
		},
		{
			// An invalid client is the one case expected to yield 401 plus a
			// Www-Authenticate header.
			err:      oauth2.NewError(oauth2.ErrorInvalidClient),
			wantCode: http.StatusUnauthorized,
			wantHeader: http.Header{
				"Content-Type":     {"application/json"},
				"Www-Authenticate": {"Basic"},
			},
			wantBody: `{"error":"invalid_client"}`,
		},
		{
			err:        oauth2.NewError(oauth2.ErrorServerError),
			wantCode:   http.StatusBadRequest,
			wantHeader: jsonOnly(),
			wantBody:   `{"error":"server_error"}`,
		},
		{
			err:        oauth2.NewError(oauth2.ErrorUnsupportedGrantType),
			wantCode:   http.StatusBadRequest,
			wantHeader: jsonOnly(),
			wantBody:   `{"error":"unsupported_grant_type"}`,
		},
		{
			// A non-oauth2 error is expected to be reported as server_error.
			err:        errors.New("generic failure"),
			wantCode:   http.StatusBadRequest,
			wantHeader: jsonOnly(),
			wantBody:   `{"error":"server_error"}`,
		},
	}

	for idx, tc := range cases {
		rec := httptest.NewRecorder()
		writeTokenError(rec, tc.err, tc.state)

		if rec.Code != tc.wantCode {
			t.Errorf("case %d: incorrect HTTP status: want=%d got=%d", idx, tc.wantCode, rec.Code)
		}

		if gotHeader := rec.Header(); !reflect.DeepEqual(tc.wantHeader, gotHeader) {
			t.Errorf("case %d: incorrect HTTP headers: want=%#v got=%#v", idx, tc.wantHeader, gotHeader)
		}

		if gotBody := rec.Body.String(); gotBody != tc.wantBody {
			t.Errorf("case %d: incorrect HTTP body: want=%q got=%q", idx, tc.wantBody, gotBody)
		}
	}
}
예제 #26
0
파일: server.go 프로젝트: Tecsisa/dex
// RefreshToken exchanges a refresh token for a freshly signed ID token and a
// renewed refresh token.
//
// The client is authenticated with creds and token is verified against the
// refresh token repo for that client. When scopes is empty, the scopes stored
// with the refresh token are used for the scope-derived claims; otherwise the
// stored scopes must contain the requested ones or the request is rejected
// with invalid_request.
//
// A "groups" claim is added whenever the scopes stored with the refresh token
// include scope.ScopeGroups; the groups are fetched from the connector
// recorded with the token.
//
// It returns the signed JWT, the renewed refresh token and the JWT's expiry
// time. On failure the first three results are zero values and the error is
// an oauth2 error (invalid_client, invalid_request or server_error).
func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, time.Time, error) {
	ok, err := s.ClientManager.Authenticate(creds)
	if err != nil {
		log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
		return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
	}
	if !ok {
		log.Errorf("Failed to Authenticate client %s", creds.ID)
		return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient)
	}

	userID, connectorID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
	switch err {
	case nil:
		// Token verified; continue below.
	case refresh.ErrorInvalidToken:
		return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidRequest)
	case refresh.ErrorInvalidClientID:
		return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidClient)
	default:
		// Record the cause: the caller only ever sees a generic server_error.
		log.Errorf("Failed to verify refresh token for client %s: %v", creds.ID, err)
		return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
	}

	if len(scopes) == 0 {
		// No scopes requested: fall back to the ones stored with the token.
		scopes = rtScopes
	} else if !rtScopes.Contains(scopes) {
		return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorInvalidRequest)
	}

	usr, err := s.UserRepo.Get(nil, userID)
	if err != nil {
		// The error can be user.ErrorNotFound, but we are not deleting
		// user at this moment, so this shouldn't happen.
		log.Errorf("Failed to fetch user %q from repo: %v: ", userID, err)
		return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
	}

	var groups []string
	if rtScopes.HasScope(scope.ScopeGroups) {
		conn, ok := s.connector(connectorID)
		if !ok {
			log.Errorf("refresh token contained invalid connector ID (%s)", connectorID)
			return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
		}

		grouper, ok := conn.(connector.GroupsConnector)
		if !ok {
			log.Errorf("refresh token requested groups for connector (%s) that doesn't support groups", connectorID)
			return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
		}

		remoteIdentities, err := s.UserRepo.GetRemoteIdentities(nil, userID)
		if err != nil {
			log.Errorf("failed to get remote identities: %v", err)
			return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
		}

		// Pick the user's identity at the connector this token was issued through.
		var remoteIdentity user.RemoteIdentity
		found := false
		for _, ri := range remoteIdentities {
			if ri.ConnectorID == connectorID {
				remoteIdentity, found = ri, true
				break
			}
		}
		if !found {
			log.Errorf("failed to get remote identity for connector %s", connectorID)
			return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
		}

		if groups, err = grouper.Groups(remoteIdentity.ID); err != nil {
			// Log the underlying error as well as the connector it came from.
			log.Errorf("failed to get groups for refresh token (connector %s): %v", connectorID, err)
			return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
		}
	}

	signer, err := s.KeyManager.Signer()
	if err != nil {
		log.Errorf("Failed to refresh ID token: %v", err)
		return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
	}

	now := time.Now()
	expiresAt := now.Add(session.DefaultSessionValidityWindow)

	claims := oidc.NewClaims(s.IssuerURL.String(), usr.ID, creds.ID, now, expiresAt)
	usr.AddToClaims(claims)
	if rtScopes.HasScope(scope.ScopeGroups) {
		if groups == nil {
			// Use an empty, non-nil slice so the claim is never a nil value.
			groups = []string{}
		}
		claims["groups"] = groups
	}

	// Do not issue a token when the scope-derived claims could not be added
	// (e.g. a cross-client request that is not authorized). The error is
	// returned as-is; NOTE(review): addClaimsFromScope is expected to return
	// oauth2 errors already — confirm against its definition in this file.
	if err := s.addClaimsFromScope(claims, scope.Scopes(scopes), creds.ID); err != nil {
		return nil, "", time.Time{}, err
	}

	jwt, err := jose.NewSignedJWT(claims, signer)
	if err != nil {
		log.Errorf("Failed to generate ID token: %v", err)
		return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
	}

	refreshToken, err := s.RefreshTokenRepo.RenewRefreshToken(creds.ID, userID, token)
	if err != nil {
		log.Errorf("Failed to generate new refresh token: %v", err)
		return nil, "", time.Time{}, oauth2.NewError(oauth2.ErrorServerError)
	}

	log.Infof("New token sent: clientID=%s", creds.ID)

	return jwt, refreshToken, expiresAt, nil
}