예제 #1
0
// GetOrCreateRefreshToken returns the refresh token stored for the given
// client and user, creating one when none exists. If the stored token has
// already expired, it is deleted and replaced with a newly created token.
//
// user may be nil; the user ID condition is then built from
// util.PositiveIntOrNull(0). expiresIn and scope are only used when a new
// token has to be created and are passed straight to NewRefreshToken.
//
// A database error from the lookup (other than "record not found"), from the
// delete or from the create is returned together with a nil token.
func (s *Service) GetOrCreateRefreshToken(client *Client, user *User, expiresIn int, scope string) (*RefreshToken, error) {
	clientID := util.PositiveIntOrNull(int64(client.ID))
	userID := util.PositiveIntOrNull(0) // user ID can be NULL
	if user != nil {
		userID = util.PositiveIntOrNull(int64(user.ID))
	}

	// Try to fetch an existing refresh token first
	refreshToken := new(RefreshToken)
	query := s.db.Where(RefreshToken{
		ClientID: clientID,
		UserID:   userID,
	}).Preload("Client").Preload("User").First(refreshToken)
	found := !query.RecordNotFound()

	// Any failure other than "record not found" is a real error. Without this
	// check refreshToken would still be zero-valued, its zero ExpiresAt would
	// make it look expired, and that empty record would be handed to Delete.
	if found && query.Error != nil {
		return nil, query.Error
	}

	if found {
		// A token that is still valid is returned untouched
		if !time.Now().After(refreshToken.ExpiresAt) {
			return refreshToken, nil
		}

		// The token has expired: remove it before issuing a replacement
		if err := s.db.Unscoped().Delete(refreshToken).Error; err != nil {
			return nil, err
		}
	}

	// Create a new refresh token because none was found or the old one expired
	refreshToken = NewRefreshToken(client, user, expiresIn, scope)
	if err := s.db.Create(refreshToken).Error; err != nil {
		return nil, err
	}

	return refreshToken, nil
}
예제 #2
0
func TestPositiveIntOrNull(t *testing.T) {
	var (
		nullInt sql.NullInt64
		value   driver.Value
		err     error
	)

	// When the number is negative
	nullInt = util.PositiveIntOrNull(-1)

	// nullInt.Valid should be false
	assert.False(t, nullInt.Valid)

	// nullInt.Value() should return nil
	value, err = nullInt.Value()
	assert.Nil(t, err)
	assert.Nil(t, value)

	// When the number is greater than zero
	nullInt = util.PositiveIntOrNull(1)

	// nullInt.Valid should be true
	assert.True(t, nullInt.Valid)

	// nullInt.Value() should return the integer
	value, err = nullInt.Value()
	assert.Nil(t, err)
	assert.Equal(t, int64(1), value)
}
예제 #3
0
// deleteExpiredAccessTokens removes the access tokens of the given client and
// user whose expires_at is not later than the current time. user may be nil,
// in which case the user ID condition is built from util.PositiveIntOrNull(0).
// The result of the delete is not inspected.
func (s *Service) deleteExpiredAccessTokens(client *Client, user *User) {
	// Conditions identifying whose tokens are affected
	owner := AccessToken{
		ClientID: util.PositiveIntOrNull(int64(client.ID)),
		UserID:   util.PositiveIntOrNull(0), // user ID can be NULL
	}
	if user != nil {
		owner.UserID = util.PositiveIntOrNull(int64(user.ID))
	}

	// Narrow the selection down to tokens that have already expired
	expired := s.db.Unscoped().Where(owner).Where("expires_at <= ?", time.Now())
	expired.Delete(&AccessToken{})
}
예제 #4
0
// TestIntOrNull verifies that util.PositiveIntOrNull(1) is valid and that its
// driver value is int64(1).
func TestIntOrNull(t *testing.T) {
	got := util.PositiveIntOrNull(1)
	assert.True(t, got.Valid)

	// The driver value must be readable without error and equal the input
	driverValue, valueErr := got.Value()
	assert.Nil(t, valueErr)
	assert.Equal(t, int64(1), driverValue)
}
예제 #5
0
// NewAuthorizationCode builds an AuthorizationCode for the given client and
// user. The code value comes from uuid.New() and the code expires expiresIn
// seconds from now. redirectURI is stored via util.StringOrNull and scope is
// stored as given. Both client and user must be non-nil, as their IDs are read
// unconditionally.
func NewAuthorizationCode(client *Client, user *User, expiresIn int, redirectURI, scope string) *AuthorizationCode {
	code := &AuthorizationCode{
		ClientID:    util.PositiveIntOrNull(int64(client.ID)),
		UserID:      util.PositiveIntOrNull(int64(user.ID)),
		Code:        uuid.New(),
		ExpiresAt:   time.Now().Add(time.Duration(expiresIn) * time.Second),
		RedirectURI: util.StringOrNull(redirectURI),
		Scope:       scope,
	}

	// Attach the related records only when their IDs are valid (non-NULL)
	if code.ClientID.Valid {
		code.Client = client
	}
	if code.UserID.Valid {
		code.User = user
	}

	return code
}
예제 #6
0
// NewAccessToken builds an AccessToken for the given client and, optionally,
// user. The token value comes from uuid.New() and the token expires expiresIn
// seconds from now. user may be nil, in which case the user ID is built from
// util.PositiveIntOrNull(0) and no User is attached.
func NewAccessToken(client *Client, user *User, expiresIn int, scope string) *AccessToken {
	token := &AccessToken{
		ClientID:  util.PositiveIntOrNull(int64(client.ID)),
		UserID:    util.PositiveIntOrNull(0), // user ID can be NULL
		Token:     uuid.New(),
		ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second),
		Scope:     scope,
	}
	if user != nil {
		token.UserID = util.PositiveIntOrNull(int64(user.ID))
	}

	// Attach the related records only when their IDs are valid (non-NULL)
	if token.ClientID.Valid {
		token.Client = client
	}
	if token.UserID.Valid {
		token.User = user
	}

	return token
}
예제 #7
0
// GetValidRefreshToken looks up the refresh token with the given token value
// belonging to client and returns it if it has not expired. The related Client
// and User records are preloaded.
//
// It returns ErrRefreshTokenNotFound when no matching record exists,
// ErrRefreshTokenExpired when the stored token's ExpiresAt is in the past, and
// the underlying query error for any other database failure.
func (s *Service) GetValidRefreshToken(token string, client *Client) (*RefreshToken, error) {
	// Fetch the refresh token from the database
	refreshToken := new(RefreshToken)
	query := s.db.Where(RefreshToken{
		ClientID: util.PositiveIntOrNull(int64(client.ID)),
	}).Where("token = ?", token).Preload("Client").Preload("User").
		First(refreshToken)

	// Not found
	if query.RecordNotFound() {
		return nil, ErrRefreshTokenNotFound
	}

	// Any other query failure leaves refreshToken zero-valued; report the real
	// error instead of letting the zero ExpiresAt be mistaken for an expired
	// token.
	if query.Error != nil {
		return nil, query.Error
	}

	// Check the refresh token hasn't expired
	if time.Now().After(refreshToken.ExpiresAt) {
		return nil, ErrRefreshTokenExpired
	}

	return refreshToken, nil
}