예제 #1
0
파일: api.go 프로젝트: jwmaag/dex
// CreateClient stores a new client built from the request.
//
// When the request leaves the client ID or secret empty, a value generated
// with storage.NewID is filled in (the secret is two generated IDs joined
// together). The generated values are written onto req.Client, which is also
// the client returned in the response.
//
// An error is returned when the request carries no client or when the storage
// layer fails to create it.
func (d dexAPI) CreateClient(ctx context.Context, req *api.CreateClientReq) (*api.CreateClientResp, error) {
	client := req.Client
	if client == nil {
		return nil, errors.New("no client supplied")
	}

	// Fill in whatever identifying material the caller did not provide.
	if client.Id == "" {
		client.Id = storage.NewID()
	}
	if client.Secret == "" {
		client.Secret = storage.NewID() + storage.NewID()
	}

	err := d.s.CreateClient(storage.Client{
		ID:           client.Id,
		Secret:       client.Secret,
		RedirectURIs: client.RedirectUris,
		TrustedPeers: client.TrustedPeers,
		Public:       client.Public,
		Name:         client.Name,
		LogoURL:      client.LogoUrl,
	})
	if err != nil {
		d.logger.Errorf("api: failed to create client: %v", err)
		// TODO(ericchiang): Surface "already exists" errors.
		return nil, fmt.Errorf("create client: %v", err)
	}

	resp := &api.CreateClientResp{Client: client}
	return resp, nil
}
예제 #2
0
파일: handlers.go 프로젝트: jwmaag/dex
// handleHealth answers a health probe by round-tripping a throwaway auth
// request through the storage backend: the record is created and immediately
// deleted. A failure of either step is logged and rendered as a 500; on
// success the elapsed time is written to the response.
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
	began := s.now()

	// Rather than introspecting health, exercise the underlying storage directly.
	probe := storage.AuthRequest{
		ID:       storage.NewID(),
		ClientID: storage.NewID(),

		// Keep the expiry short: if the delete below fails, garbage
		// collection will remove the record soon afterwards.
		Expiry: s.now().Add(time.Minute),
	}

	var checkErr error
	if err := s.storage.CreateAuthRequest(probe); err != nil {
		checkErr = fmt.Errorf("create auth request: %v", err)
	} else if err := s.storage.DeleteAuthRequest(probe.ID); err != nil {
		checkErr = fmt.Errorf("delete auth request: %v", err)
	}

	elapsed := s.now().Sub(began)
	if checkErr != nil {
		s.logger.Errorf("Storage health check failed: %v", checkErr)
		s.renderError(w, http.StatusInternalServerError, "Health check failed.")
		return
	}
	fmt.Fprintf(w, "Health check passed in %s", elapsed)
}
예제 #3
0
// testClientConcurrentUpdate starts a second UpdateClient call on the same
// client from inside the first call's updater callback and requires that
// exactly one of the two calls reports an error.
func testClientConcurrentUpdate(t *testing.T, s storage.Storage) {
	client := storage.Client{
		ID:           storage.NewID(),
		Secret:       "foobar",
		RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"},
		Name:         "dex client",
		LogoURL:      "https://goo.gl/JIyzIC",
	}

	if err := s.CreateClient(client); err != nil {
		t.Fatalf("create client: %v", err)
	}

	var outerErr, innerErr error

	// The inner updater runs while the outer update is still in progress.
	innerUpdate := func(old storage.Client) (storage.Client, error) {
		old.Secret = "new secret 2"
		return old, nil
	}
	outerUpdate := func(old storage.Client) (storage.Client, error) {
		old.Secret = "new secret 1"
		innerErr = s.UpdateClient(client.ID, innerUpdate)
		return old, nil
	}
	outerErr = s.UpdateClient(client.ID, outerUpdate)

	// Both succeeding or both failing is a test failure.
	outerOK := outerErr == nil
	innerOK := innerErr == nil
	if outerOK == innerOK {
		t.Errorf("update client:\nupdate1: %v\nupdate2: %v\n", outerErr, innerErr)
	}
}
예제 #4
0
파일: handlers.go 프로젝트: jwmaag/dex
// writeAccessToken writes the JSON token response to w.
//
// The response carries a freshly generated access_token, a "bearer"
// token_type, the number of seconds from now until expiry, the ID token, and
// the refresh token (omitted from the JSON when empty). If marshaling fails,
// the error is logged and a server error is sent through tokenErrHelper.
func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, refreshToken string, expiry time.Time) {
	type tokenResponse struct {
		AccessToken  string `json:"access_token"`
		TokenType    string `json:"token_type"`
		ExpiresIn    int    `json:"expires_in"`
		RefreshToken string `json:"refresh_token,omitempty"`
		IDToken      string `json:"id_token"`
	}

	// TODO(ericchiang): figure out an access token story and support the user info
	// endpoint. For now use a random value so no one depends on the access_token
	// holding a specific structure.
	payload := tokenResponse{
		AccessToken:  storage.NewID(),
		TokenType:    "bearer",
		ExpiresIn:    int(expiry.Sub(s.now()).Seconds()),
		RefreshToken: refreshToken,
		IDToken:      idToken,
	}

	body, err := json.Marshal(payload)
	if err != nil {
		s.logger.Errorf("failed to marshal access token response: %v", err)
		s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
		return
	}

	headers := w.Header()
	headers.Set("Content-Type", "application/json")
	headers.Set("Content-Length", strconv.Itoa(len(body)))
	w.Write(body)
}
예제 #5
0
파일: conformance.go 프로젝트: jwmaag/dex
// testClientCRUD walks a client through its lifecycle against s: deleting a
// client that was never created must yield "not found", then the client is
// created, read back, updated with a new secret, read back again, deleted,
// and finally a read must yield "not found".
func testClientCRUD(t *testing.T, s storage.Storage) {
	clientID := storage.NewID()
	want := storage.Client{
		ID:           clientID,
		Secret:       "foobar",
		RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"},
		Name:         "dex client",
		LogoURL:      "https://goo.gl/JIyzIC",
	}

	// The client does not exist yet, so the delete has to fail with not-found.
	mustBeErrNotFound(t, "client", s.DeleteClient(clientID))

	if err := s.CreateClient(want); err != nil {
		t.Fatalf("create client: %v", err)
	}

	// assertStored fetches the client and diffs it against expected.
	assertStored := func(expected storage.Client) {
		stored, err := s.GetClient(clientID)
		if err != nil {
			t.Errorf("get client: %v", err)
			return
		}
		diff := pretty.Compare(expected, stored)
		if diff != "" {
			t.Errorf("client retrieved from storage did not match: %s", diff)
		}
	}

	assertStored(want)

	const rotatedSecret = "barfoo"
	updateErr := s.UpdateClient(clientID, func(old storage.Client) (storage.Client, error) {
		old.Secret = rotatedSecret
		return old, nil
	})
	if updateErr != nil {
		t.Errorf("update client: %v", updateErr)
	}
	want.Secret = rotatedSecret
	assertStored(want)

	if err := s.DeleteClient(clientID); err != nil {
		t.Fatalf("delete client: %v", err)
	}

	_, getErr := s.GetClient(clientID)
	mustBeErrNotFound(t, "client", getErr)
}
예제 #6
0
파일: conformance.go 프로젝트: jwmaag/dex
// testAuthRequestCRUD creates an auth request, replaces its Claims and sets
// its ConnectorID through UpdateAuthRequest, then reads the request back and
// verifies that the stored Claims equal the ones written by the update.
//
// NOTE(review): the AuthRequest literal below appears truncated in this copy.
// The RedirectURI line is followed directly by fields and a closing brace of
// a nested literal that is never opened, so the block does not parse as
// shown. Restore the missing lines from the original source.
func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
	a := storage.AuthRequest{
		ID:                  storage.NewID(),
		ClientID:            "foobar",
		ResponseTypes:       []string{"code"},
		Scopes:              []string{"openid", "email"},
		RedirectURI:         "https://*****:*****@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
	}

	// Claims value the update is expected to persist.
	identity := storage.Claims{Email: "foobar"}

	if err := s.CreateAuthRequest(a); err != nil {
		t.Fatalf("failed creating auth request: %v", err)
	}
	if err := s.UpdateAuthRequest(a.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) {
		old.Claims = identity
		old.ConnectorID = "connID"
		return old, nil
	}); err != nil {
		t.Fatalf("failed to update auth request: %v", err)
	}

	// Read back and confirm the update took effect.
	got, err := s.GetAuthRequest(a.ID)
	if err != nil {
		t.Fatalf("failed to get auth req: %v", err)
	}
	if !reflect.DeepEqual(got.Claims, identity) {
		t.Fatalf("update failed, wanted identity=%#v got %#v", identity, got.Claims)
	}
}
예제 #7
0
파일: conformance.go 프로젝트: jwmaag/dex
// testTimezones tests that backends either fully support timezones or
// do the correct standardization.
//
// It stores an auth code, reads it back, converts the returned expiry to the
// America/New_York location and requires it to denote the same instant as
// the locally computed expiry.
//
// NOTE(review): the AuthCode literal below appears truncated in this copy.
// The RedirectURI line is followed by the tail of a nested literal that is
// never opened, and no visible line assigns expiry to the auth code.
// Presumably the missing lines set Expiry — confirm against the original
// source.
func testTimezones(t *testing.T, s storage.Storage) {
	est, err := time.LoadLocation("America/New_York")
	if err != nil {
		t.Fatal(err)
	}
	// Create an expiry with timezone info. Only expect backends to be
	// accurate to the millisecond
	expiry := time.Now().In(est).Round(time.Millisecond)

	c := storage.AuthCode{
		ID:            storage.NewID(),
		ClientID:      "foobar",
		RedirectURI:   "https://*****:*****@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
	}
	if err := s.CreateAuthCode(c); err != nil {
		t.Fatalf("failed creating auth code: %v", err)
	}
	got, err := s.GetAuthCode(c.ID)
	if err != nil {
		t.Fatalf("failed to get auth code: %v", err)
	}

	// Ensure that if the resulting time is converted to the same
	// timezone, it's the same value. We DO NOT expect timezones
	// to be preserved.
	gotTime := got.Expiry.In(est)
	wantTime := expiry
	// Equal compares instants, so differing locations alone do not fail this.
	if !gotTime.Equal(wantTime) {
		t.Fatalf("expected expiry %v got %v", wantTime, gotTime)
	}
}
예제 #8
0
// testAuthRequestConcurrentUpdate issues a second UpdateAuthRequest on the
// same auth request from inside the first call's updater callback and
// requires that exactly one of the two calls returns an error.
//
// NOTE(review): the AuthRequest literal below appears truncated in this copy.
// The RedirectURI line is followed by the tail of a nested literal that is
// never opened. Restore the missing lines from the original source.
func testAuthRequestConcurrentUpdate(t *testing.T, s storage.Storage) {
	a := storage.AuthRequest{
		ID:                  storage.NewID(),
		ClientID:            "foobar",
		ResponseTypes:       []string{"code"},
		Scopes:              []string{"openid", "email"},
		RedirectURI:         "https://*****:*****@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
	}

	if err := s.CreateAuthRequest(a); err != nil {
		t.Fatalf("failed creating auth request: %v", err)
	}

	var err1, err2 error

	// The inner update runs while the outer update's callback is still
	// executing; err2 captures its result.
	err1 = s.UpdateAuthRequest(a.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) {
		old.State = "state 1"
		err2 = s.UpdateAuthRequest(a.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) {
			old.State = "state 2"
			return old, nil
		})
		return old, nil
	})

	// Fails when both updates succeeded or both failed.
	if (err1 == nil) == (err2 == nil) {
		t.Errorf("update auth request:\nupdate1: %v\nupdate2: %v\n", err1, err2)
	}
}
예제 #9
0
파일: conformance.go 프로젝트: jwmaag/dex
// testAuthCodeCRUD creates an auth code, reads it back and compares it with
// the original, deletes it, and finally checks that a further read reports
// "not found".
//
// NOTE(review): the AuthCode literal below appears truncated in this copy.
// The RedirectURI line is followed by the tail of a nested literal that is
// never opened. Restore the missing lines from the original source.
func testAuthCodeCRUD(t *testing.T, s storage.Storage) {
	a := storage.AuthCode{
		ID:            storage.NewID(),
		ClientID:      "foobar",
		RedirectURI:   "https://*****:*****@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
	}

	if err := s.CreateAuthCode(a); err != nil {
		t.Fatalf("failed creating auth code: %v", err)
	}

	got, err := s.GetAuthCode(a.ID)
	if err != nil {
		t.Fatalf("failed to get auth req: %v", err)
	}
	// Expiry is checked separately, at one-second granularity.
	if a.Expiry.Unix() != got.Expiry.Unix() {
		t.Errorf("auth code expiry did not match want=%s vs got=%s", a.Expiry, got.Expiry)
	}
	// Overwrite the retrieved expiry so the structural diff below ignores it.
	got.Expiry = a.Expiry // time fields do not compare well
	if diff := pretty.Compare(a, got); diff != "" {
		t.Errorf("auth code retrieved from storage did not match: %s", diff)
	}

	if err := s.DeleteAuthCode(a.ID); err != nil {
		t.Fatalf("delete auth code: %v", err)
	}

	// After deletion the code must no longer be retrievable.
	_, err = s.GetAuthCode(a.ID)
	mustBeErrNotFound(t, "auth code", err)
}
예제 #10
0
파일: conformance.go 프로젝트: jwmaag/dex
// testRefreshTokenCRUD creates a refresh token, reads it back and diffs it
// against what was written, deletes it, and then requires a further read to
// return storage.ErrNotFound.
func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
	tokenID := storage.NewID()
	want := storage.RefreshToken{
		RefreshToken: tokenID,
		ClientID:     "client_id",
		ConnectorID:  "client_secret",
		Scopes:       []string{"openid", "email", "profile"},
		Claims: storage.Claims{
			UserID:        "1",
			Username:      "******",
			Email:         "*****@*****.**",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
	}
	if err := s.CreateRefresh(want); err != nil {
		t.Fatalf("create refresh token: %v", err)
	}

	// Read the token back and compare it field by field with the original.
	stored, err := s.GetRefresh(tokenID)
	if err != nil {
		t.Errorf("get refresh: %v", err)
	} else if diff := pretty.Compare(want, stored); diff != "" {
		t.Errorf("refresh token retrieved from storage did not match: %s", diff)
	}

	if err := s.DeleteRefresh(tokenID); err != nil {
		t.Fatalf("failed to delete refresh request: %v", err)
	}

	// Once deleted, the token must be reported as missing.
	_, err = s.GetRefresh(tokenID)
	if err != storage.ErrNotFound {
		t.Errorf("after deleting refresh expected storage.ErrNotFound, got %v", err)
	}
}
예제 #11
0
파일: conformance.go 프로젝트: jwmaag/dex
// testGC exercises GarbageCollect for auth codes and then for auth requests.
//
// For each kind of object it creates one record, runs GarbageCollect with a
// cutoff one hour before expiry expressed in UTC, America/New_York and
// America/Los_Angeles, and requires that nothing is collected and the record
// is still readable. It then runs GarbageCollect with a cutoff one hour after
// expiry and requires exactly one object of that kind to be collected and
// subsequent reads to return storage.ErrNotFound.
//
// NOTE(review): both struct literals below appear truncated in this copy.
// Each RedirectURI line is followed by the tail of a nested literal that is
// never opened, and no visible line assigns expiry to either object.
// Presumably the missing lines set Expiry — confirm against the original
// source.
func testGC(t *testing.T, s storage.Storage) {
	est, err := time.LoadLocation("America/New_York")
	if err != nil {
		t.Fatal(err)
	}
	pst, err := time.LoadLocation("America/Los_Angeles")
	if err != nil {
		t.Fatal(err)
	}

	// Reference instant, expressed in the New York location.
	expiry := time.Now().In(est)
	c := storage.AuthCode{
		ID:            storage.NewID(),
		ClientID:      "foobar",
		RedirectURI:   "https://*****:*****@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
	}

	if err := s.CreateAuthCode(c); err != nil {
		t.Fatalf("failed creating auth code: %v", err)
	}

	// A cutoff one hour before expiry must collect nothing, whichever
	// location the cutoff is expressed in.
	for _, tz := range []*time.Location{time.UTC, est, pst} {
		result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
		if err != nil {
			t.Errorf("garbage collection failed: %v", err)
		} else {
			if result.AuthCodes != 0 || result.AuthRequests != 0 {
				t.Errorf("expected no garbage collection results, got %#v", result)
			}
		}
		if _, err := s.GetAuthCode(c.ID); err != nil {
			t.Errorf("expected to be able to get auth code after GC: %v", err)
		}
	}

	// A cutoff one hour after expiry must collect exactly the one auth code.
	if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
		t.Errorf("garbage collection failed: %v", err)
	} else if r.AuthCodes != 1 {
		t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes)
	}

	if _, err := s.GetAuthCode(c.ID); err == nil {
		t.Errorf("expected auth code to be GC'd")
	} else if err != storage.ErrNotFound {
		t.Errorf("expected storage.ErrNotFound, got %v", err)
	}

	// Repeat the same sequence for an auth request.
	a := storage.AuthRequest{
		ID:                  storage.NewID(),
		ClientID:            "foobar",
		ResponseTypes:       []string{"code"},
		Scopes:              []string{"openid", "email"},
		RedirectURI:         "https://*****:*****@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
	}

	if err := s.CreateAuthRequest(a); err != nil {
		t.Fatalf("failed creating auth request: %v", err)
	}

	for _, tz := range []*time.Location{time.UTC, est, pst} {
		result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
		if err != nil {
			t.Errorf("garbage collection failed: %v", err)
		} else {
			if result.AuthCodes != 0 || result.AuthRequests != 0 {
				t.Errorf("expected no garbage collection results, got %#v", result)
			}
		}
		// NOTE(review): the failure messages below say "auth code" although
		// the object being checked is an auth request.
		if _, err := s.GetAuthRequest(a.ID); err != nil {
			t.Errorf("expected to be able to get auth code after GC: %v", err)
		}
	}

	if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
		t.Errorf("garbage collection failed: %v", err)
	} else if r.AuthRequests != 1 {
		t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests)
	}

	if _, err := s.GetAuthRequest(a.ID); err == nil {
		t.Errorf("expected auth code to be GC'd")
	} else if err != storage.ErrNotFound {
		t.Errorf("expected storage.ErrNotFound, got %v", err)
	}
}
예제 #12
0
파일: oauth2.go 프로젝트: ericchiang/dex
// parseAuthorizationRequest parses and validates the initial request from the
// OAuth2 client and, on success, returns the storage.AuthRequest describing
// it with a freshly generated ID.
//
// Checks are applied in this order:
//   - the form must parse and redirect_uri must unescape cleanly;
//   - client_id must name a client known to s;
//   - redirect_uri must pass validateRedirectURI for that client;
//   - the scopes must include "openid", every scope must be recognized, and
//     every cross-client scope must be trusted for this client;
//   - every response_type must be enabled in supportedResponseTypes; the
//     token response type additionally requires a nonce and cannot be used
//     with the out-of-band redirect URI.
//
// Errors produced before the redirect URI has been validated are built with
// empty state and redirect URI values; later errors carry both.
//
// For correctness the logic is largely copied from https://github.com/RangelReale/osin.
func parseAuthorizationRequest(s storage.Storage, supportedResponseTypes map[string]bool, r *http.Request) (req storage.AuthRequest, oauth2Err *authErr) {
	if err := r.ParseForm(); err != nil {
		return req, &authErr{"", "", errInvalidRequest, "Failed to parse request."}
	}

	// NOTE(review): this branch is taken when the value fails to unescape,
	// not when it is absent; the message is kept as-is for compatibility.
	redirectURI, err := url.QueryUnescape(r.Form.Get("redirect_uri"))
	if err != nil {
		return req, &authErr{"", "", errInvalidRequest, "No redirect_uri provided."}
	}
	state := r.FormValue("state")

	clientID := r.Form.Get("client_id")

	client, err := s.GetClient(clientID)
	if err != nil {
		if err == storage.ErrNotFound {
			description := fmt.Sprintf("Invalid client_id (%q).", clientID)
			return req, &authErr{"", "", errUnauthorizedClient, description}
		}
		log.Printf("Failed to get client: %v", err)
		return req, &authErr{"", "", errServerError, ""}
	}

	if !validateRedirectURI(client, redirectURI) {
		description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI)
		return req, &authErr{"", "", errInvalidRequest, description}
	}

	// From here on the redirect URI is known to belong to the client, so
	// errors are built with the state and redirect URI attached.
	newErr := func(typ, format string, a ...interface{}) *authErr {
		return &authErr{state, redirectURI, typ, fmt.Sprintf(format, a...)}
	}

	scopes := strings.Fields(r.Form.Get("scope"))

	var (
		unrecognized  []string
		invalidScopes []string
	)
	hasOpenIDScope := false
	for _, scope := range scopes {
		switch scope {
		case scopeOpenID:
			hasOpenIDScope = true
		case scopeOfflineAccess, scopeEmail, scopeProfile, scopeGroups:
		default:
			// Anything else must be a cross-client scope naming a peer
			// that trusts this client.
			peerID, ok := parseCrossClientScope(scope)
			if !ok {
				unrecognized = append(unrecognized, scope)
				continue
			}

			isTrusted, err := validateCrossClientTrust(s, clientID, peerID)
			if err != nil {
				return req, newErr(errServerError, "")
			}
			if !isTrusted {
				invalidScopes = append(invalidScopes, scope)
			}
		}
	}
	if !hasOpenIDScope {
		return req, newErr("invalid_scope", `Missing required scope(s) ["openid"].`)
	}
	if len(unrecognized) > 0 {
		return req, newErr("invalid_scope", "Unrecognized scope(s) %q", unrecognized)
	}
	if len(invalidScopes) > 0 {
		return req, newErr("invalid_scope", "Client can't request scope(s) %q", invalidScopes)
	}

	nonce := r.Form.Get("nonce")
	responseTypes := strings.Split(r.Form.Get("response_type"), " ")
	for _, responseType := range responseTypes {
		if !supportedResponseTypes[responseType] {
			return req, newErr("invalid_request", "Invalid response type %q", responseType)
		}

		switch responseType {
		case responseTypeCode:
		case responseTypeToken:
			// Implicit flow requires a nonce value.
			// https://openid.net/specs/openid-connect-core-1_0.html#ImplicitAuthRequest
			if nonce == "" {
				return req, newErr("invalid_request", "Response type 'token' requires a 'nonce' value.")
			}

			if redirectURI == redirectURIOOB {
				// Pass the URI as an argument rather than pre-formatting it:
				// newErr treats its second parameter as a format string.
				return req, newErr("invalid_request", "Cannot use response type 'token' with redirect_uri '%s'.", redirectURIOOB)
			}
		default:
			return req, newErr("invalid_request", "Invalid response type %q", responseType)
		}
	}

	return storage.AuthRequest{
		ID:                  storage.NewID(),
		ClientID:            client.ID,
		State:               state,
		Nonce:               nonce,
		ForceApprovalPrompt: r.Form.Get("approval_prompt") == "force",
		Scopes:              scopes,
		RedirectURI:         redirectURI,
		ResponseTypes:       responseTypes,
	}, nil
}
예제 #13
0
파일: handlers.go 프로젝트: jwmaag/dex
// handleRefreshToken handles a refresh token request
// https://tools.ietf.org/html/rfc6749#section-6
//
// The refresh token named in the request must exist and belong to client;
// otherwise an invalid-request error is returned. Storage failures other
// than "not found" are logged and reported as server errors. If the request
// names scopes they must all be among the scopes originally granted. When
// the connector implements connector.RefreshConnector the stored identity is
// refreshed through it. On success the old refresh token is deleted, a new
// one is created, and a token response is written.
func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, client storage.Client) {
	code := r.PostFormValue("refresh_token")
	scope := r.PostFormValue("scope")
	if code == "" {
		s.tokenErrHelper(w, errInvalidRequest, "No refresh token in request.", http.StatusBadRequest)
		return
	}

	refresh, err := s.storage.GetRefresh(code)
	if err != nil {
		if err != storage.ErrNotFound {
			s.logger.Errorf("failed to get refresh token: %v", err)
			s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
		} else {
			s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
		}
		return
	}
	// A token owned by a different client is a client error, not a server
	// failure. It is checked separately so a nil err is never treated as a
	// storage error.
	if refresh.ClientID != client.ID {
		s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
		return
	}

	// Per the OAuth2 spec, if the client has omitted the scopes, default to the original
	// authorized scopes.
	//
	// https://tools.ietf.org/html/rfc6749#section-6
	scopes := refresh.Scopes
	if scope != "" {
		requestedScopes := strings.Fields(scope)
		var unauthorizedScopes []string

		for _, requested := range requestedScopes {
			contains := func() bool {
				for _, granted := range refresh.Scopes {
					if requested == granted {
						return true
					}
				}
				return false
			}()
			if !contains {
				unauthorizedScopes = append(unauthorizedScopes, requested)
			}
		}

		if len(unauthorizedScopes) > 0 {
			msg := fmt.Sprintf("Requested scopes contain unauthorized scope(s): %q.", unauthorizedScopes)
			s.tokenErrHelper(w, errInvalidRequest, msg, http.StatusBadRequest)
			return
		}
		scopes = requestedScopes
	}

	conn, ok := s.connectors[refresh.ConnectorID]
	if !ok {
		s.logger.Errorf("connector ID not found: %q", refresh.ConnectorID)
		s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
		return
	}

	// Can the connector refresh the identity? If so, attempt to refresh the data
	// in the connector.
	//
	// TODO(ericchiang): We may want a strict mode where connectors that don't implement
	// this interface can't perform refreshing.
	if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok {
		ident := connector.Identity{
			UserID:        refresh.Claims.UserID,
			Username:      refresh.Claims.Username,
			Email:         refresh.Claims.Email,
			EmailVerified: refresh.Claims.EmailVerified,
			Groups:        refresh.Claims.Groups,
			ConnectorData: refresh.ConnectorData,
		}
		ident, err := refreshConn.Refresh(r.Context(), parseScopes(scopes), ident)
		if err != nil {
			s.logger.Errorf("failed to refresh identity: %v", err)
			s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
			return
		}

		// Update the claims of the refresh token.
		//
		// UserID intentionally ignored for now.
		refresh.Claims.Username = ident.Username
		refresh.Claims.Email = ident.Email
		refresh.Claims.EmailVerified = ident.EmailVerified
		refresh.Claims.Groups = ident.Groups
		refresh.ConnectorData = ident.ConnectorData
	}

	idToken, expiry, err := s.newIDToken(client.ID, refresh.Claims, scopes, refresh.Nonce)
	if err != nil {
		s.logger.Errorf("failed to create ID token: %v", err)
		s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
		return
	}

	// Refresh tokens are claimed exactly once. Delete the current token and
	// create a new one.
	if err := s.storage.DeleteRefresh(code); err != nil {
		s.logger.Errorf("failed to delete refresh token: %v", err)
		s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
		return
	}
	refresh.RefreshToken = storage.NewID()
	if err := s.storage.CreateRefresh(refresh); err != nil {
		s.logger.Errorf("failed to create refresh token: %v", err)
		s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
		return
	}
	s.writeAccessToken(w, idToken, refresh.RefreshToken, expiry)
}
예제 #14
0
파일: handlers.go 프로젝트: jwmaag/dex
// handleAuthCode handles an access token request
// https://tools.ietf.org/html/rfc6749#section-4.1.3
//
// The auth code named in the request must exist, be unexpired, belong to
// client, and have been issued for the same redirect_uri; any of these
// failing yields an invalid-request error. Storage failures other than
// "not found" are logged and reported as server errors. On success the code
// is deleted, a refresh token is created when the offline access scope was
// granted, and a token response is written.
func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client storage.Client) {
	code := r.PostFormValue("code")
	redirectURI := r.PostFormValue("redirect_uri")

	authCode, err := s.storage.GetAuthCode(code)
	if err != nil {
		if err != storage.ErrNotFound {
			s.logger.Errorf("failed to get auth code: %v", err)
			s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
		} else {
			s.tokenErrHelper(w, errInvalidRequest, "Invalid or expired code parameter.", http.StatusBadRequest)
		}
		return
	}
	// An expired code, or one issued to another client, is the caller's
	// mistake. It is checked separately so a nil err is never reported as a
	// storage failure.
	if s.now().After(authCode.Expiry) || authCode.ClientID != client.ID {
		s.tokenErrHelper(w, errInvalidRequest, "Invalid or expired code parameter.", http.StatusBadRequest)
		return
	}

	if authCode.RedirectURI != redirectURI {
		s.tokenErrHelper(w, errInvalidRequest, "redirect_uri did not match URI from initial request.", http.StatusBadRequest)
		return
	}

	idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce)
	if err != nil {
		s.logger.Errorf("failed to create ID token: %v", err)
		s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
		return
	}

	// The code is removed before any token is handed out.
	if err := s.storage.DeleteAuthCode(code); err != nil {
		s.logger.Errorf("failed to delete auth code: %v", err)
		s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
		return
	}

	// A refresh token is only issued when the offline access scope was granted.
	reqRefresh := func() bool {
		for _, scope := range authCode.Scopes {
			if scope == scopeOfflineAccess {
				return true
			}
		}
		return false
	}()
	var refreshToken string
	if reqRefresh {
		refresh := storage.RefreshToken{
			RefreshToken:  storage.NewID(),
			ClientID:      authCode.ClientID,
			ConnectorID:   authCode.ConnectorID,
			Scopes:        authCode.Scopes,
			Claims:        authCode.Claims,
			Nonce:         authCode.Nonce,
			ConnectorData: authCode.ConnectorData,
		}
		if err := s.storage.CreateRefresh(refresh); err != nil {
			s.logger.Errorf("failed to create refresh token: %v", err)
			s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
			return
		}
		refreshToken = refresh.RefreshToken
	}
	s.writeAccessToken(w, idToken, refreshToken, expiry)
}
예제 #15
0
파일: handlers.go 프로젝트: jwmaag/dex
// sendCodeResponse completes an authorization request by sending the user
// agent back to the client's redirect URI with the requested credentials.
//
// The auth request must not be expired. It is deleted from storage before
// anything is issued; if it is already gone the user gets a 400. For each
// requested response type:
//   - code: an auth code valid for 30 minutes is stored and added as the
//     "code" query parameter; for the out-of-band redirect URI the code is
//     rendered with the oob template instead and no redirect happens;
//   - token: an ID token is created and returned, with a generated access
//     token, in the URL fragment.
//
// Finally the state is added to the query string and a 303 redirect is sent.
func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authReq storage.AuthRequest) {
	if s.now().After(authReq.Expiry) {
		s.renderError(w, http.StatusBadRequest, "User session has expired.")
		return
	}

	// Remove the auth request first; a missing request is reported as a
	// session error rather than a server error.
	if err := s.storage.DeleteAuthRequest(authReq.ID); err != nil {
		if err != storage.ErrNotFound {
			s.logger.Errorf("Failed to delete authorization request: %v", err)
			s.renderError(w, http.StatusInternalServerError, "Internal server error.")
		} else {
			s.renderError(w, http.StatusBadRequest, "User session error.")
		}
		return
	}
	u, err := url.Parse(authReq.RedirectURI)
	if err != nil {
		s.renderError(w, http.StatusInternalServerError, "Invalid redirect URI.")
		return
	}
	// Start from the redirect URI's existing query parameters.
	q := u.Query()

	for _, responseType := range authReq.ResponseTypes {
		switch responseType {
		case responseTypeCode:
			code := storage.AuthCode{
				ID:            storage.NewID(),
				ClientID:      authReq.ClientID,
				ConnectorID:   authReq.ConnectorID,
				Nonce:         authReq.Nonce,
				Scopes:        authReq.Scopes,
				Claims:        authReq.Claims,
				Expiry:        s.now().Add(time.Minute * 30),
				RedirectURI:   authReq.RedirectURI,
				ConnectorData: authReq.ConnectorData,
			}
			if err := s.storage.CreateAuthCode(code); err != nil {
				s.logger.Errorf("Failed to create auth code: %v", err)
				s.renderError(w, http.StatusInternalServerError, "Internal server error.")
				return
			}

			// Out-of-band flow: show the code to the user instead of
			// redirecting. This returns before any remaining response
			// types are processed.
			if authReq.RedirectURI == redirectURIOOB {
				if err := s.templates.oob(w, code.ID); err != nil {
					s.logger.Errorf("Server template error: %v", err)
				}
				return
			}
			q.Set("code", code.ID)
		case responseTypeToken:
			idToken, expiry, err := s.newIDToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce)
			if err != nil {
				s.logger.Errorf("failed to create ID token: %v", err)
				s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
				return
			}
			// Token values travel in the URL fragment rather than the query.
			v := url.Values{}
			v.Set("access_token", storage.NewID())
			v.Set("token_type", "bearer")
			v.Set("id_token", idToken)
			v.Set("state", authReq.State)
			v.Set("expires_in", strconv.Itoa(int(expiry.Sub(s.now()).Seconds())))
			u.Fragment = v.Encode()
		}
	}

	q.Set("state", authReq.State)
	u.RawQuery = q.Encode()
	http.Redirect(w, r, u.String(), http.StatusSeeOther)
}
예제 #16
0
// testGC exercises GarbageCollect for an auth code and then for an auth
// request. For each object it runs GarbageCollect at the reference time n
// and requires the object to remain readable, then runs it at n plus one
// minute and requires exactly one object of that kind to be collected and a
// subsequent read to return storage.ErrNotFound.
//
// NOTE(review): both struct literals below appear truncated in this copy.
// Each RedirectURI line is followed by the tail of a nested literal that is
// never opened, and no visible line ties n to either object's expiry.
// Presumably the missing lines set Expiry relative to n — confirm against
// the original source.
func testGC(t *testing.T, s storage.Storage) {
	n := time.Now().UTC()
	c := storage.AuthCode{
		ID:            storage.NewID(),
		ClientID:      "foobar",
		RedirectURI:   "https://*****:*****@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
	}

	if err := s.CreateAuthCode(c); err != nil {
		t.Fatalf("failed creating auth code: %v", err)
	}

	// Collecting at n must leave the auth code in place.
	if _, err := s.GarbageCollect(n); err != nil {
		t.Errorf("garbage collection failed: %v", err)
	}
	if _, err := s.GetAuthCode(c.ID); err != nil {
		t.Errorf("expected to be able to get auth code after GC: %v", err)
	}

	// Collecting one minute later must remove exactly that one auth code.
	if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil {
		t.Errorf("garbage collection failed: %v", err)
	} else if r.AuthCodes != 1 {
		t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes)
	}

	if _, err := s.GetAuthCode(c.ID); err == nil {
		t.Errorf("expected auth code to be GC'd")
	} else if err != storage.ErrNotFound {
		t.Errorf("expected storage.ErrNotFound, got %v", err)
	}

	// Repeat the same sequence for an auth request.
	a := storage.AuthRequest{
		ID:                  storage.NewID(),
		ClientID:            "foobar",
		ResponseTypes:       []string{"code"},
		Scopes:              []string{"openid", "email"},
		RedirectURI:         "https://*****:*****@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
	}

	if err := s.CreateAuthRequest(a); err != nil {
		t.Fatalf("failed creating auth request: %v", err)
	}

	if _, err := s.GarbageCollect(n); err != nil {
		t.Errorf("garbage collection failed: %v", err)
	}
	// NOTE(review): the failure messages below say "auth code" although the
	// object being checked is an auth request.
	if _, err := s.GetAuthRequest(a.ID); err != nil {
		t.Errorf("expected to be able to get auth code after GC: %v", err)
	}

	if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil {
		t.Errorf("garbage collection failed: %v", err)
	} else if r.AuthRequests != 1 {
		t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests)
	}

	if _, err := s.GetAuthRequest(a.ID); err == nil {
		t.Errorf("expected auth code to be GC'd")
	} else if err != storage.ErrNotFound {
		t.Errorf("expected storage.ErrNotFound, got %v", err)
	}
}