Example #1
0
func testKeysCRUD(t *testing.T, s storage.Storage) {
	// Postgres isn't as accurate with nano seconds as we'd like, so every
	// timestamp in this test is rounded to whole seconds.
	now := time.Now().UTC().Round(time.Second)

	// roundTrip replaces the stored keys with want, reads them back and
	// reports any difference. NextRotation is normalized to UTC before the
	// comparison so a backend that returns another location still matches.
	roundTrip := func(want storage.Keys) {
		if err := s.UpdateKeys(func(storage.Keys) (storage.Keys, error) {
			return want, nil
		}); err != nil {
			t.Errorf("failed to update keys: %v", err)
			return
		}

		got, err := s.GetKeys()
		if err != nil {
			t.Errorf("failed to get keys: %v", err)
			return
		}
		got.NextRotation = got.NextRotation.UTC()
		if diff := pretty.Compare(want, got); diff != "" {
			t.Errorf("got keys did not equal expected: %s", diff)
		}
	}

	// First a bare signing key, then a rotated key set that also carries
	// two verification keys.
	cases := []storage.Keys{
		{
			SigningKey:    jsonWebKeys[0].Private,
			SigningKeyPub: jsonWebKeys[0].Public,
			NextRotation:  now,
		},
		{
			SigningKey:    jsonWebKeys[2].Private,
			SigningKeyPub: jsonWebKeys[2].Public,
			NextRotation:  now.Add(time.Hour),
			VerificationKeys: []storage.VerificationKey{
				{PublicKey: jsonWebKeys[0].Public, Expiry: now.Add(time.Hour)},
				{PublicKey: jsonWebKeys[1].Public, Expiry: now.Add(2 * time.Hour)},
			},
		},
	}
	for _, want := range cases {
		roundTrip(want)
	}
}
Example #2
0
// startGarbageCollection launches a background goroutine that repeatedly
// waits for frequency and then calls s.GarbageCollect(now()). The goroutine
// exits when ctx is done.
//
// A failed run is logged and the loop keeps going. A successful run is only
// logged when it actually deleted auth requests or auth codes.
func startGarbageCollection(ctx context.Context, s storage.Storage, frequency time.Duration, now func() time.Time) {
	go func() {
		for {
			select {
			case <-ctx.Done():
				return
			case <-time.After(frequency):
				r, err := s.GarbageCollect(now())
				switch {
				case err != nil:
					log.Printf("garbage collection failed: %v", err)
				case r.AuthRequests > 0 || r.AuthCodes > 0:
					log.Printf("garbage collection run, delete auth requests=%d, auth codes=%d", r.AuthRequests, r.AuthCodes)
				}
			}
		}
	}()
}
Example #3
0
// testTimezones tests that backends either fully support timezones or
// do the correct standardization.
//
// It builds an expiry in the America/New_York location (rounded to the
// millisecond), stores an auth code, reads it back, and requires the
// returned Expiry, once converted to America/New_York, to be Equal to the
// original. The location itself is not required to round-trip.
func testTimezones(t *testing.T, s storage.Storage) {
	est, err := time.LoadLocation("America/New_York")
	if err != nil {
		t.Fatal(err)
	}
	// Create an expiry with timezone info. Only expect backends to be
	// accurate to the millisecond
	expiry := time.Now().In(est).Round(time.Millisecond)

	// NOTE(review): this literal appears damaged. The RedirectURI string runs
	// straight into the tail of what looks like a nested storage.Claims
	// literal (EmailVerified/Groups), and `expiry` is never referenced, so
	// the block does not compile as shown. Presumably an `Expiry: expiry`
	// field and the Claims opener were lost; restore from upstream.
	c := storage.AuthCode{
		ID:            storage.NewID(),
		ClientID:      "foobar",
		RedirectURI:   "https://*****:*****@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
	}
	if err := s.CreateAuthCode(c); err != nil {
		t.Fatalf("failed creating auth code: %v", err)
	}
	got, err := s.GetAuthCode(c.ID)
	if err != nil {
		t.Fatalf("failed to get auth code: %v", err)
	}

	// Ensure that if the resulting time is converted to the same
	// timezone, it's the same value. We DO NOT expect timezones
	// to be preserved.
	gotTime := got.Expiry.In(est)
	wantTime := expiry
	// Equal compares instants, not locations or monotonic readings.
	if !gotTime.Equal(wantTime) {
		t.Fatalf("expected expiry %v got %v", wantTime, gotTime)
	}
}
Example #4
0
// validateCrossClientTrust reports whether clientID is allowed to request a
// cross-client scope for peerID.
//
// A client always trusts itself. Otherwise the peer client is loaded and
// clientID must appear in its TrustedPeers list. An unknown peer yields
// (false, nil). Any other storage error is logged and returned.
func validateCrossClientTrust(s storage.Storage, clientID, peerID string) (trusted bool, err error) {
	if clientID == peerID {
		return true, nil
	}

	peer, lookupErr := s.GetClient(peerID)
	switch {
	case lookupErr == storage.ErrNotFound:
		// A peer that doesn't exist simply isn't trusted.
		return false, nil
	case lookupErr != nil:
		log.Printf("Failed to get client: %v", lookupErr)
		return false, lookupErr
	}

	for _, trustedID := range peer.TrustedPeers {
		if trustedID == clientID {
			return true, nil
		}
	}
	return false, nil
}
Example #5
0
// testClientConcurrentUpdate issues a second UpdateClient call from inside
// the callback of a first one on the same client. It requires exactly one
// of the two calls to return an error.
func testClientConcurrentUpdate(t *testing.T, s storage.Storage) {
	client := storage.Client{
		ID:           storage.NewID(),
		Secret:       "foobar",
		RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"},
		Name:         "dex client",
		LogoURL:      "https://goo.gl/JIyzIC",
	}
	if err := s.CreateClient(client); err != nil {
		t.Fatalf("create client: %v", err)
	}

	// innerErr is written by the nested call while the outer update is
	// still in progress.
	var innerErr error
	outerErr := s.UpdateClient(client.ID, func(old storage.Client) (storage.Client, error) {
		old.Secret = "new secret 1"
		innerErr = s.UpdateClient(client.ID, func(old storage.Client) (storage.Client, error) {
			old.Secret = "new secret 2"
			return old, nil
		})
		return old, nil
	})

	outerOK, innerOK := outerErr == nil, innerErr == nil
	if outerOK == innerOK {
		t.Errorf("update client:\nupdate1: %v\nupdate2: %v\n", outerErr, innerErr)
	}
}
Example #6
0
// testKeysConcurrentUpdate issues a second UpdateKeys call from inside the
// callback of a first one. It requires exactly one of the two calls to
// return an error.
//
// On failure both errors are printed, so "both succeeded" can be told apart
// from "both failed". The old message only described the first case.
func testKeysConcurrentUpdate(t *testing.T, s storage.Storage) {
	// Test twice. Once for a create, once for an update.
	for i := 0; i < 2; i++ {
		n := time.Now().UTC().Round(time.Second)
		keys1 := storage.Keys{
			SigningKey:    jsonWebKeys[0].Private,
			SigningKeyPub: jsonWebKeys[0].Public,
			NextRotation:  n,
		}

		keys2 := storage.Keys{
			SigningKey:    jsonWebKeys[2].Private,
			SigningKeyPub: jsonWebKeys[2].Public,
			NextRotation:  n.Add(time.Hour),
			VerificationKeys: []storage.VerificationKey{
				{
					PublicKey: jsonWebKeys[0].Public,
					Expiry:    n.Add(time.Hour),
				},
				{
					PublicKey: jsonWebKeys[1].Public,
					Expiry:    n.Add(time.Hour * 2),
				},
			},
		}

		var err1, err2 error

		err1 = s.UpdateKeys(func(storage.Keys) (storage.Keys, error) {
			// Nested update while the outer one is still in progress.
			err2 = s.UpdateKeys(func(storage.Keys) (storage.Keys, error) {
				return keys1, nil
			})
			return keys2, nil
		})

		// Fires when both succeeded OR both failed.
		if (err1 == nil) == (err2 == nil) {
			t.Errorf("update keys (pass %d):\nupdate1: %v\nupdate2: %v\n", i, err1, err2)
		}
	}
}
Example #7
0
// testPasswordConcurrentUpdate issues a second UpdatePassword call from
// inside the callback of a first one on the same email. It requires exactly
// one of the two calls to return an error.
//
// On failure both errors are printed, so "both succeeded" can be told apart
// from "both failed". The old message only described the first case.
func testPasswordConcurrentUpdate(t *testing.T, s storage.Storage) {
	// Use bcrypt.MinCost to keep the tests short.
	passwordHash, err := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.MinCost)
	if err != nil {
		t.Fatal(err)
	}

	password := storage.Password{
		Email:    "*****@*****.**",
		Hash:     passwordHash,
		Username: "******",
		UserID:   "foobar",
	}
	if err := s.CreatePassword(password); err != nil {
		t.Fatalf("create password token: %v", err)
	}

	var err1, err2 error

	err1 = s.UpdatePassword(password.Email, func(old storage.Password) (storage.Password, error) {
		old.Username = "******"
		// Nested update while the outer one is still in progress.
		err2 = s.UpdatePassword(password.Email, func(old storage.Password) (storage.Password, error) {
			old.Username = "******"
			return old, nil
		})
		return old, nil
	})

	// Fires when both succeeded OR both failed.
	if (err1 == nil) == (err2 == nil) {
		t.Errorf("update password:\nupdate1: %v\nupdate2: %v\n", err1, err2)
	}
}
Example #8
0
// testAuthCodeCRUD creates an auth code, reads it back and compares it with
// what was written, then deletes it. The error from a lookup after deletion
// is passed to mustBeErrNotFound (defined elsewhere; presumably asserts
// storage.ErrNotFound).
func testAuthCodeCRUD(t *testing.T, s storage.Storage) {
	// NOTE(review): this literal appears damaged. The RedirectURI string runs
	// straight into the tail of what looks like a nested storage.Claims
	// literal (EmailVerified/Groups), so the block does not compile as shown.
	// Any Expiry/Claims fields that used to be set are missing; restore the
	// literal from upstream before relying on the expiry check below.
	a := storage.AuthCode{
		ID:            storage.NewID(),
		ClientID:      "foobar",
		RedirectURI:   "https://*****:*****@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
	}

	if err := s.CreateAuthCode(a); err != nil {
		t.Fatalf("failed creating auth code: %v", err)
	}

	got, err := s.GetAuthCode(a.ID)
	if err != nil {
		t.Fatalf("failed to get auth req: %v", err)
	}
	// Expiry is checked at whole-second granularity, then copied over so
	// the structural diff below ignores time representation differences.
	if a.Expiry.Unix() != got.Expiry.Unix() {
		t.Errorf("auth code expiry did not match want=%s vs got=%s", a.Expiry, got.Expiry)
	}
	got.Expiry = a.Expiry // time fields do not compare well
	if diff := pretty.Compare(a, got); diff != "" {
		t.Errorf("auth code retrieved from storage did not match: %s", diff)
	}

	if err := s.DeleteAuthCode(a.ID); err != nil {
		t.Fatalf("delete auth code: %v", err)
	}

	_, err = s.GetAuthCode(a.ID)
	mustBeErrNotFound(t, "auth code", err)
}
Example #9
0
// testRefreshTokenCRUD stores a refresh token, verifies that reading it back
// yields an identical value, deletes it, and requires a later lookup to fail
// with storage.ErrNotFound.
func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
	tokenID := storage.NewID()
	want := storage.RefreshToken{
		RefreshToken: tokenID,
		ClientID:     "client_id",
		ConnectorID:  "client_secret",
		Scopes:       []string{"openid", "email", "profile"},
		Claims: storage.Claims{
			UserID:        "1",
			Username:      "******",
			Email:         "*****@*****.**",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
	}
	if err := s.CreateRefresh(want); err != nil {
		t.Fatalf("create refresh token: %v", err)
	}

	// A failed read is reported but does not stop the test; the diff is
	// only computed when the read succeeded.
	if got, err := s.GetRefresh(tokenID); err != nil {
		t.Errorf("get refresh: %v", err)
	} else if diff := pretty.Compare(want, got); diff != "" {
		t.Errorf("refresh token retrieved from storage did not match: %s", diff)
	}

	if err := s.DeleteRefresh(tokenID); err != nil {
		t.Fatalf("failed to delete refresh request: %v", err)
	}

	if _, err := s.GetRefresh(tokenID); err != storage.ErrNotFound {
		t.Errorf("after deleting refresh expected storage.ErrNotFound, got %v", err)
	}
}
Example #10
0
// testAuthRequestCRUD creates an auth request and updates it with new
// Claims and ConnectorID. It then reads the request back and requires the
// stored Claims to be reflect.DeepEqual to the ones written.
//
// NOTE(review): ConnectorID is set by the update but never verified.
func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
	// NOTE(review): this literal appears damaged. The RedirectURI string runs
	// straight into the tail of what looks like a nested storage.Claims
	// literal (EmailVerified/Groups), so the block does not compile as shown.
	// Restore the missing fields from upstream.
	a := storage.AuthRequest{
		ID:                  storage.NewID(),
		ClientID:            "foobar",
		ResponseTypes:       []string{"code"},
		Scopes:              []string{"openid", "email"},
		RedirectURI:         "https://*****:*****@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
	}

	identity := storage.Claims{Email: "foobar"}

	if err := s.CreateAuthRequest(a); err != nil {
		t.Fatalf("failed creating auth request: %v", err)
	}
	// Overwrite the claims and connector in a single read-modify-write.
	if err := s.UpdateAuthRequest(a.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) {
		old.Claims = identity
		old.ConnectorID = "connID"
		return old, nil
	}); err != nil {
		t.Fatalf("failed to update auth request: %v", err)
	}

	got, err := s.GetAuthRequest(a.ID)
	if err != nil {
		t.Fatalf("failed to get auth req: %v", err)
	}
	if !reflect.DeepEqual(got.Claims, identity) {
		t.Fatalf("update failed, wanted identity=%#v got %#v", identity, got.Claims)
	}
}
Example #11
0
// testAuthRequestConcurrentUpdate issues a second UpdateAuthRequest call
// from inside the callback of a first one on the same request. It requires
// exactly one of the two calls to return an error.
func testAuthRequestConcurrentUpdate(t *testing.T, s storage.Storage) {
	// NOTE(review): this literal appears damaged. The RedirectURI string runs
	// straight into the tail of what looks like a nested storage.Claims
	// literal (EmailVerified/Groups), so the block does not compile as shown.
	// Restore the missing fields from upstream.
	a := storage.AuthRequest{
		ID:                  storage.NewID(),
		ClientID:            "foobar",
		ResponseTypes:       []string{"code"},
		Scopes:              []string{"openid", "email"},
		RedirectURI:         "https://*****:*****@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
	}

	if err := s.CreateAuthRequest(a); err != nil {
		t.Fatalf("failed creating auth request: %v", err)
	}

	var err1, err2 error

	// err2 is assigned by the nested call while the outer update is still
	// in progress.
	err1 = s.UpdateAuthRequest(a.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) {
		old.State = "state 1"
		err2 = s.UpdateAuthRequest(a.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) {
			old.State = "state 2"
			return old, nil
		})
		return old, nil
	})

	// Fires when both succeeded or both failed.
	if (err1 == nil) == (err2 == nil) {
		t.Errorf("update auth request:\nupdate1: %v\nupdate2: %v\n", err1, err2)
	}
}
Example #12
0
// testGC checks garbage collection of auth codes and then auth requests
// across timezones.
//
// For each object type it creates one object, then:
//   - runs GarbageCollect at expiry-1h, expressed in UTC, America/New_York
//     and America/Los_Angeles, expecting zero collected objects and the
//     object still readable;
//   - runs GarbageCollect at expiry+1h, expecting exactly one object
//     collected;
//   - requires a final lookup to fail with storage.ErrNotFound.
//
// NOTE(review): several messages in the auth-request half still say
// "auth code" (copy-paste from the first half).
func testGC(t *testing.T, s storage.Storage) {
	est, err := time.LoadLocation("America/New_York")
	if err != nil {
		t.Fatal(err)
	}
	pst, err := time.LoadLocation("America/Los_Angeles")
	if err != nil {
		t.Fatal(err)
	}

	expiry := time.Now().In(est)
	// NOTE(review): this literal appears damaged. The RedirectURI string runs
	// straight into the tail of a nested Claims literal, so the block does
	// not compile as shown. Presumably an `Expiry: expiry` field was lost;
	// restore from upstream.
	c := storage.AuthCode{
		ID:            storage.NewID(),
		ClientID:      "foobar",
		RedirectURI:   "https://*****:*****@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
	}

	if err := s.CreateAuthCode(c); err != nil {
		t.Fatalf("failed creating auth code: %v", err)
	}

	// The same instant in three locations must all be treated as "before
	// expiry".
	for _, tz := range []*time.Location{time.UTC, est, pst} {
		result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
		if err != nil {
			t.Errorf("garbage collection failed: %v", err)
		} else {
			if result.AuthCodes != 0 || result.AuthRequests != 0 {
				t.Errorf("expected no garbage collection results, got %#v", result)
			}
		}
		if _, err := s.GetAuthCode(c.ID); err != nil {
			t.Errorf("expected to be able to get auth code after GC: %v", err)
		}
	}

	if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
		t.Errorf("garbage collection failed: %v", err)
	} else if r.AuthCodes != 1 {
		t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes)
	}

	if _, err := s.GetAuthCode(c.ID); err == nil {
		t.Errorf("expected auth code to be GC'd")
	} else if err != storage.ErrNotFound {
		t.Errorf("expected storage.ErrNotFound, got %v", err)
	}

	// NOTE(review): damaged in the same way as the auth code literal above.
	a := storage.AuthRequest{
		ID:                  storage.NewID(),
		ClientID:            "foobar",
		ResponseTypes:       []string{"code"},
		Scopes:              []string{"openid", "email"},
		RedirectURI:         "https://*****:*****@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
	}

	if err := s.CreateAuthRequest(a); err != nil {
		t.Fatalf("failed creating auth request: %v", err)
	}

	for _, tz := range []*time.Location{time.UTC, est, pst} {
		result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
		if err != nil {
			t.Errorf("garbage collection failed: %v", err)
		} else {
			if result.AuthCodes != 0 || result.AuthRequests != 0 {
				t.Errorf("expected no garbage collection results, got %#v", result)
			}
		}
		if _, err := s.GetAuthRequest(a.ID); err != nil {
			t.Errorf("expected to be able to get auth code after GC: %v", err)
		}
	}

	if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
		t.Errorf("garbage collection failed: %v", err)
	} else if r.AuthRequests != 1 {
		t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests)
	}

	if _, err := s.GetAuthRequest(a.ID); err == nil {
		t.Errorf("expected auth code to be GC'd")
	} else if err != storage.ErrNotFound {
		t.Errorf("expected storage.ErrNotFound, got %v", err)
	}
}
Example #13
0
// testPasswordCRUD exercises the password lifecycle: create, get (by email),
// update the username, list, delete, and require a final lookup to fail
// with storage.ErrNotFound.
func testPasswordCRUD(t *testing.T, s storage.Storage) {
	// Use bcrypt.MinCost to keep the tests short.
	passwordHash, err := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.MinCost)
	if err != nil {
		t.Fatal(err)
	}

	password := storage.Password{
		Email:    "*****@*****.**",
		Hash:     passwordHash,
		Username: "******",
		UserID:   "foobar",
	}
	if err := s.CreatePassword(password); err != nil {
		t.Fatalf("create password token: %v", err)
	}

	// getAndCompare fetches the password stored under id and reports any
	// difference from want. A failed fetch is reported and skips the diff.
	getAndCompare := func(id string, want storage.Password) {
		gr, err := s.GetPassword(id)
		if err != nil {
			t.Errorf("get password %q: %v", id, err)
			return
		}
		if diff := pretty.Compare(want, gr); diff != "" {
			t.Errorf("password retrieved from storage did not match: %s", diff)
		}
	}

	getAndCompare("*****@*****.**", password)
	// Emails should be case insensitive.
	// NOTE(review): as written this key is identical to the one above, so
	// case-insensitivity is not actually exercised. The original literal
	// appears to have been redacted; restore a differently-cased address.
	getAndCompare("*****@*****.**", password)

	if err := s.UpdatePassword(password.Email, func(old storage.Password) (storage.Password, error) {
		old.Username = "******"
		return old, nil
	}); err != nil {
		t.Fatalf("failed to update password: %v", err)
	}

	password.Username = "******"
	getAndCompare("*****@*****.**", password)

	// listAndCompare sorts both slices with byEmail (defined elsewhere)
	// before diffing, so storage ordering does not matter.
	listAndCompare := func(want []storage.Password) {
		passwords, err := s.ListPasswords()
		if err != nil {
			t.Errorf("list password: %v", err)
			return
		}
		sort.Sort(byEmail(want))
		sort.Sort(byEmail(passwords))
		if diff := pretty.Compare(want, passwords); diff != "" {
			t.Errorf("password list retrieved from storage did not match: %s", diff)
		}
	}

	listAndCompare([]storage.Password{password})

	if err := s.DeletePassword(password.Email); err != nil {
		t.Fatalf("failed to delete password: %v", err)
	}

	if _, err := s.GetPassword(password.Email); err != storage.ErrNotFound {
		t.Errorf("after deleting password expected storage.ErrNotFound, got %v", err)
	}
}
Example #14
0
// testClientCRUD exercises the client lifecycle:
//   - deleting a client that doesn't exist yet;
//   - create, then read back and compare;
//   - update the secret, then read back and compare again;
//   - delete, then look up again.
//
// Both not-found expectations go through mustBeErrNotFound (defined
// elsewhere).
func testClientCRUD(t *testing.T, s storage.Storage) {
	clientID := storage.NewID()
	want := storage.Client{
		ID:           clientID,
		Secret:       "foobar",
		RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"},
		Name:         "dex client",
		LogoURL:      "https://goo.gl/JIyzIC",
	}

	mustBeErrNotFound(t, "client", s.DeleteClient(clientID))

	if err := s.CreateClient(want); err != nil {
		t.Fatalf("create client: %v", err)
	}

	// checkStored reads the client back and reports any difference from
	// expected. A failed read is reported and skips the diff.
	checkStored := func(expected storage.Client) {
		got, err := s.GetClient(clientID)
		if err != nil {
			t.Errorf("get client: %v", err)
			return
		}
		if diff := pretty.Compare(expected, got); diff != "" {
			t.Errorf("client retrieved from storage did not match: %s", diff)
		}
	}

	checkStored(want)

	const newSecret = "barfoo"
	if err := s.UpdateClient(clientID, func(old storage.Client) (storage.Client, error) {
		old.Secret = newSecret
		return old, nil
	}); err != nil {
		t.Errorf("update client: %v", err)
	}
	want.Secret = newSecret
	checkStored(want)

	if err := s.DeleteClient(clientID); err != nil {
		t.Fatalf("delete client: %v", err)
	}

	_, err := s.GetClient(clientID)
	mustBeErrNotFound(t, "client", err)
}
Example #15
0
// parseAuthorizationRequest parses the initial authorization request sent by
// an OAuth2 client and validates it against the stored client registration.
//
// It checks, in order:
//   - the form parses;
//   - redirect_uri unescapes;
//   - client_id names a known client;
//   - redirect_uri is registered for that client;
//   - the requested scopes;
//   - the response types, including the nonce and redirect_uri rules for
//     the implicit 'token' flow.
//
// Errors returned before redirect_uri has been validated leave the state and
// redirect URI of the returned authErr empty. Errors after that point carry
// both (see newErr).
//
// For correctness the logic is largely copied from https://github.com/RangelReale/osin.
func parseAuthorizationRequest(s storage.Storage, supportedResponseTypes map[string]bool, r *http.Request) (req storage.AuthRequest, oauth2Err *authErr) {
	if err := r.ParseForm(); err != nil {
		return req, &authErr{"", "", errInvalidRequest, "Failed to parse request."}
	}

	redirectURI, err := url.QueryUnescape(r.Form.Get("redirect_uri"))
	if err != nil {
		return req, &authErr{"", "", errInvalidRequest, "No redirect_uri provided."}
	}
	// r.Form is populated by ParseForm above, so this is the same value
	// r.FormValue would return.
	state := r.Form.Get("state")
	clientID := r.Form.Get("client_id")

	client, err := s.GetClient(clientID)
	if err != nil {
		if err == storage.ErrNotFound {
			description := fmt.Sprintf("Invalid client_id (%q).", clientID)
			return req, &authErr{"", "", errUnauthorizedClient, description}
		}
		log.Printf("Failed to get client: %v", err)
		return req, &authErr{"", "", errServerError, ""}
	}

	if !validateRedirectURI(client, redirectURI) {
		description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI)
		return req, &authErr{"", "", errInvalidRequest, description}
	}

	// From here on redirectURI is registered for this client, so errors
	// carry state and redirect URI.
	newErr := func(typ, format string, a ...interface{}) *authErr {
		return &authErr{state, redirectURI, typ, fmt.Sprintf(format, a...)}
	}

	scopes := strings.Fields(r.Form.Get("scope"))

	var (
		unrecognized  []string
		invalidScopes []string
	)
	hasOpenIDScope := false
	for _, scope := range scopes {
		switch scope {
		case scopeOpenID:
			hasOpenIDScope = true
		case scopeOfflineAccess, scopeEmail, scopeProfile, scopeGroups:
		default:
			peerID, ok := parseCrossClientScope(scope)
			if !ok {
				unrecognized = append(unrecognized, scope)
				continue
			}

			isTrusted, err := validateCrossClientTrust(s, clientID, peerID)
			if err != nil {
				return req, newErr(errServerError, "")
			}
			if !isTrusted {
				invalidScopes = append(invalidScopes, scope)
			}
		}
	}
	if !hasOpenIDScope {
		return req, newErr("invalid_scope", `Missing required scope(s) ["openid"].`)
	}
	if len(unrecognized) > 0 {
		return req, newErr("invalid_scope", "Unrecognized scope(s) %q", unrecognized)
	}
	if len(invalidScopes) > 0 {
		return req, newErr("invalid_scope", "Client can't request scope(s) %q", invalidScopes)
	}

	nonce := r.Form.Get("nonce")
	responseTypes := strings.Split(r.Form.Get("response_type"), " ")
	for _, responseType := range responseTypes {
		if !supportedResponseTypes[responseType] {
			return req, newErr("invalid_request", "Invalid response type %q", responseType)
		}

		switch responseType {
		case responseTypeCode:
		case responseTypeToken:
			// Implicit flow requires a nonce value.
			// https://openid.net/specs/openid-connect-core-1_0.html#ImplicitAuthRequest
			if nonce == "" {
				return req, newErr("invalid_request", "Response type 'token' requires a 'nonce' value.")
			}

			if redirectURI == redirectURIOOB {
				// Pass the URI as an argument; the message must not itself be
				// used as a format string.
				return req, newErr("invalid_request", "Cannot use response type 'token' with redirect_uri '%s'.", redirectURIOOB)
			}
		default:
			return req, newErr("invalid_request", "Invalid response type %q", responseType)
		}
	}

	return storage.AuthRequest{
		ID:                  storage.NewID(),
		ClientID:            client.ID,
		State:               state,
		Nonce:               nonce,
		ForceApprovalPrompt: r.Form.Get("approval_prompt") == "force",
		Scopes:              scopes,
		RedirectURI:         redirectURI,
		ResponseTypes:       responseTypes,
	}, nil
}
Example #16
0
// testGC checks garbage collection of auth codes and then auth requests.
//
// For each object type it creates one object, then:
//   - runs GarbageCollect at n (the test's start time, UTC) and requires the
//     object to still be readable; the collection count is not checked;
//   - runs GarbageCollect at n+1m and requires exactly one object collected;
//   - requires a final lookup to fail with storage.ErrNotFound.
//
// NOTE(review): some messages in the auth-request half still say
// "auth code" (copy-paste from the first half).
func testGC(t *testing.T, s storage.Storage) {
	n := time.Now().UTC()
	// NOTE(review): this literal appears damaged. The RedirectURI string runs
	// straight into the tail of a nested Claims literal, so the block does
	// not compile as shown. Presumably an Expiry field relative to n was
	// lost; restore from upstream.
	c := storage.AuthCode{
		ID:            storage.NewID(),
		ClientID:      "foobar",
		RedirectURI:   "https://*****:*****@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
	}

	if err := s.CreateAuthCode(c); err != nil {
		t.Fatalf("failed creating auth code: %v", err)
	}

	if _, err := s.GarbageCollect(n); err != nil {
		t.Errorf("garbage collection failed: %v", err)
	}
	if _, err := s.GetAuthCode(c.ID); err != nil {
		t.Errorf("expected to be able to get auth code after GC: %v", err)
	}

	if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil {
		t.Errorf("garbage collection failed: %v", err)
	} else if r.AuthCodes != 1 {
		t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes)
	}

	if _, err := s.GetAuthCode(c.ID); err == nil {
		t.Errorf("expected auth code to be GC'd")
	} else if err != storage.ErrNotFound {
		t.Errorf("expected storage.ErrNotFound, got %v", err)
	}

	// NOTE(review): damaged in the same way as the auth code literal above.
	a := storage.AuthRequest{
		ID:                  storage.NewID(),
		ClientID:            "foobar",
		ResponseTypes:       []string{"code"},
		Scopes:              []string{"openid", "email"},
		RedirectURI:         "https://*****:*****@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
	}

	if err := s.CreateAuthRequest(a); err != nil {
		t.Fatalf("failed creating auth request: %v", err)
	}

	if _, err := s.GarbageCollect(n); err != nil {
		t.Errorf("garbage collection failed: %v", err)
	}
	if _, err := s.GetAuthRequest(a.ID); err != nil {
		t.Errorf("expected to be able to get auth code after GC: %v", err)
	}

	if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil {
		t.Errorf("garbage collection failed: %v", err)
	} else if r.AuthRequests != 1 {
		t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests)
	}

	if _, err := s.GetAuthRequest(a.ID); err == nil {
		t.Errorf("expected auth code to be GC'd")
	} else if err != storage.ErrNotFound {
		t.Errorf("expected storage.ErrNotFound, got %v", err)
	}
}