Example #1
0
// TestDBRefreshRepoRevoke creates one refresh token for "user-foo" /
// "client-foo" and then checks the error returned by Revoke for a table of
// token / user ID combinations: malformed tokens, a token with a different ID
// prefix, a token with a freshly generated payload, the real token with the
// wrong user, and finally the real token with the right user.
func TestDBRefreshRepoRevoke(t *testing.T) {
	repo := db.NewRefreshTokenRepo(connect(t))

	validToken, err := repo.Create("user-foo", "client-foo")
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	// A second, independently generated payload that was never stored.
	freshPayload, err := refresh.DefaultRefreshTokenGenerator()
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	// Swap the first byte of the real token for "404" so the ID part no
	// longer matches (assumes the ID occupies the first byte -- TODO confirm).
	wrongIDToken := "404" + validToken[1:]
	wrongPayloadToken := buildRefreshToken(1, freshPayload)

	type revokeCase struct {
		token   string
		userID  string
		wantErr error
	}

	cases := []revokeCase{
		{token: "invalid-token-format", userID: "user-foo", wantErr: refresh.ErrorInvalidToken},
		{token: "1/invalid-base64-encoded-format", userID: "user-foo", wantErr: refresh.ErrorInvalidToken},
		{token: validToken + "corrupted-token-payload", userID: "user-foo", wantErr: refresh.ErrorInvalidToken},
		// The token's ID is invalid.
		{token: wrongIDToken, userID: "user-foo", wantErr: refresh.ErrorInvalidToken},
		// The token's payload is invalid.
		{token: wrongPayloadToken, userID: "user-foo", wantErr: refresh.ErrorInvalidToken},
		// Valid token, but presented by a different user.
		{token: validToken, userID: "invalid-user", wantErr: refresh.ErrorInvalidUserID},
		// The successful revocation is listed last; the cases above reuse
		// the same token and are expected to fail without revoking it.
		{token: validToken, userID: "user-foo", wantErr: nil},
	}

	for idx := range cases {
		c := cases[idx]
		got := repo.Revoke(c.userID, c.token)
		if got == c.wantErr {
			continue
		}
		t.Errorf("Case #%d: expected: %v, got: %v", idx, c.wantErr, got)
	}
}
Example #2
0
// TestDBRefreshRepoVerify creates one refresh token for "user-foo" /
// "client-foo" and then checks both return values of Verify for a table of
// token / client combinations: malformed tokens, a token with a different ID
// prefix, a token with a freshly generated payload, the real token with the
// wrong client, and the real token with the right client.
func TestDBRefreshRepoVerify(t *testing.T) {
	repo := db.NewRefreshTokenRepo(connect(t))

	validToken, err := repo.Create("user-foo", "client-foo")
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	// A second, independently generated payload that was never stored.
	freshPayload, err := refresh.DefaultRefreshTokenGenerator()
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	// Swap the first byte of the real token for "404" so the ID part no
	// longer matches (assumes the ID occupies the first byte -- TODO confirm).
	wrongIDToken := "404" + validToken[1:]
	wrongPayloadToken := buildRefreshToken(1, freshPayload)

	// Credentials matching the client the token was created for. Only the
	// ID field is handed to Verify in the loop below.
	validCreds := oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}

	type verifyCase struct {
		token    string
		creds    oidc.ClientCredentials
		wantErr  error
		wantUser string
	}

	cases := []verifyCase{
		{token: "invalid-token-format", creds: validCreds, wantErr: refresh.ErrorInvalidToken, wantUser: ""},
		{token: "b/invalid-base64-encoded-format", creds: validCreds, wantErr: refresh.ErrorInvalidToken, wantUser: ""},
		{token: "1/invalid-base64-encoded-format", creds: validCreds, wantErr: refresh.ErrorInvalidToken, wantUser: ""},
		{token: validToken + "corrupted-token-payload", creds: validCreds, wantErr: refresh.ErrorInvalidToken, wantUser: ""},
		// The token's ID content is invalid.
		{token: wrongIDToken, creds: validCreds, wantErr: refresh.ErrorInvalidToken, wantUser: ""},
		// The token's payload content is invalid.
		{token: wrongPayloadToken, creds: validCreds, wantErr: refresh.ErrorInvalidToken, wantUser: ""},
		// Valid token, but presented by a different client.
		{
			token:    validToken,
			creds:    oidc.ClientCredentials{ID: "invalid-client", Secret: "secret-foo"},
			wantErr:  refresh.ErrorInvalidClientID,
			wantUser: "",
		},
		// Valid token and matching client: the owning user ID comes back.
		{token: validToken, creds: validCreds, wantErr: nil, wantUser: "user-foo"},
	}

	for idx := range cases {
		c := cases[idx]
		gotUser, gotErr := repo.Verify(c.creds.ID, c.token)

		// Both values are checked independently so one case can report
		// a mismatch on the error and on the returned user ID.
		if gotErr != c.wantErr {
			t.Errorf("Case #%d: expected: %v, got: %v", idx, c.wantErr, gotErr)
		}
		if gotUser != c.wantUser {
			t.Errorf("Case #%d: expected: %v, got: %v", idx, c.wantUser, gotUser)
		}
	}
}