Exemplo n.º 1
0
// TestAfterRetryRefreshCreds verifies that sending a request whose error is
// unmarshaled as an ExpiredTokenException, with the core AfterRetry handler
// installed and all other handlers cleared, does not itself retrieve
// credentials: they stay expired and the provider's retrieve is only invoked
// by the next explicit Credentials.Get call.
func TestAfterRetryRefreshCreds(t *testing.T) {
	// Run with an empty environment, but put the original environment back
	// when the test finishes so later tests in the same binary do not observe
	// a cleared environment.
	origEnv := os.Environ()
	defer func() {
		os.Clearenv()
		for _, kv := range origEnv {
			// Look for the separator starting at index 1: some platforms
			// (e.g. Windows) report entries whose key itself begins with '='.
			for i := 1; i < len(kv); i++ {
				if kv[i] == '=' {
					// Best-effort restore; a failure here cannot affect
					// this test's outcome.
					_ = os.Setenv(kv[:i], kv[i+1:])
					break
				}
			}
		}
	}()
	os.Clearenv()

	credProvider := &mockCredsProvider{}

	svc := awstesting.NewClient(&aws.Config{
		Credentials: credentials.NewCredentials(credProvider),
		MaxRetries:  aws.Int(1),
	})

	// Only the handlers below run: validation always fails with a 400, the
	// error is rewritten to ExpiredTokenException, and the core AfterRetry
	// handler decides what to do with it.
	svc.Handlers.Clear()
	svc.Handlers.ValidateResponse.PushBack(func(r *request.Request) {
		r.Error = awserr.New("UnknownError", "", nil)
		r.HTTPResponse = &http.Response{StatusCode: 400}
	})
	svc.Handlers.UnmarshalError.PushBack(func(r *request.Request) {
		r.Error = awserr.New("ExpiredTokenException", "", nil)
	})
	svc.Handlers.AfterRetry.PushBackNamed(corehandlers.AfterRetryHandler)

	assert.True(t, svc.Config.Credentials.IsExpired(), "Expect to start out expired")
	assert.False(t, credProvider.retrieveCalled)

	req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
	req.Send()

	assert.True(t, svc.Config.Credentials.IsExpired())
	assert.False(t, credProvider.retrieveCalled)

	_, err := svc.Config.Credentials.Get()
	assert.NoError(t, err)
	assert.True(t, credProvider.retrieveCalled)
}
Exemplo n.º 2
0
// copyConfig converts the optional client Config into the aws.Config the
// client runs with. A nil config is treated as an empty Config.
//
// Credentials are always credentials.AnonymousCredentials and Endpoint is
// copied unchanged. HTTPClient, Logger, LogLevel and MaxRetries are copied
// from config and, when nil, replaced by http.DefaultClient,
// aws.NewDefaultLogger(), aws.LogOff and DefaultRetries respectively.
func copyConfig(config *Config) *aws.Config {
	src := config
	if src == nil {
		src = &Config{}
	}

	var out aws.Config
	out.Credentials = credentials.AnonymousCredentials
	out.Endpoint = src.Endpoint

	out.HTTPClient = src.HTTPClient
	if out.HTTPClient == nil {
		out.HTTPClient = http.DefaultClient
	}

	out.Logger = src.Logger
	if out.Logger == nil {
		out.Logger = aws.NewDefaultLogger()
	}

	out.LogLevel = src.LogLevel
	if out.LogLevel == nil {
		out.LogLevel = aws.LogLevel(aws.LogOff)
	}

	out.MaxRetries = src.MaxRetries
	if out.MaxRetries == nil {
		out.MaxRetries = aws.Int(DefaultRetries)
	}

	return &out
}
Exemplo n.º 3
0
// TestRequestRecoverExpiredCreds verifies that a request failing with an
// ExpiredTokenException is retried: the credentials are valid when the
// AfterRetry chain starts, expired once it finishes, valid again after the
// successful retry, and exactly one retry is performed.
func TestRequestRecoverExpiredCreds(t *testing.T) {
	// Canned responses handed out, in order, by the mocked Send handler:
	// first an expired-token error, then a successful payload.
	reqNum := 0
	reqs := []http.Response{
		{StatusCode: 400, Body: body(`{"__type":"ExpiredTokenException","message":"expired token"}`)},
		{StatusCode: 200, Body: body(`{"data":"valid"}`)},
	}

	s := service.New(&aws.Config{MaxRetries: aws.Int(10), Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "")})
	s.Handlers.Validate.Clear()
	s.Handlers.Unmarshal.PushBack(unmarshal)
	s.Handlers.UnmarshalError.PushBack(unmarshalError)

	credExpiredBeforeRetry := false
	credExpiredAfterRetry := false

	// Snapshot the credential state at the front of the AfterRetry chain,
	// before the service's own AfterRetry handlers run. This is what gives
	// the "valid creds before retry check" assertion below something to check.
	s.Handlers.AfterRetry.PushFront(func(r *request.Request) {
		credExpiredBeforeRetry = r.Service.Config.Credentials.IsExpired()
	})
	// Snapshot it again at the back of the chain, after the others have run.
	s.Handlers.AfterRetry.PushBack(func(r *request.Request) {
		credExpiredAfterRetry = r.Service.Config.Credentials.IsExpired()
	})

	// Replace real signing with a plain credentials fetch on every attempt,
	// reporting a fetch failure as the request error rather than dropping it.
	s.Handlers.Sign.Clear()
	s.Handlers.Sign.PushBack(func(r *request.Request) {
		if _, err := r.Service.Config.Credentials.Get(); err != nil {
			r.Error = err
		}
	})
	s.Handlers.Send.Clear() // mock sending
	s.Handlers.Send.PushBack(func(r *request.Request) {
		r.HTTPResponse = &reqs[reqNum]
		reqNum++
	})
	out := &testData{}
	r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
	err := r.Send()
	assert.NoError(t, err)

	assert.False(t, credExpiredBeforeRetry, "Expect valid creds before retry check")
	assert.True(t, credExpiredAfterRetry, "Expect expired creds after retry check")
	assert.False(t, s.Config.Credentials.IsExpired(), "Expect valid creds after cred expired recovery")

	assert.Equal(t, 1, int(r.RetryCount))
	assert.Equal(t, "valid", out.Data)
}
Exemplo n.º 4
0
// copyConfig converts the optional client Config into the aws.Config the
// client runs with. A nil config is treated as an empty Config.
//
// Credentials are always credentials.AnonymousCredentials and Endpoint is
// copied unchanged. Logger, LogLevel and MaxRetries fall back to
// aws.NewDefaultLogger(), aws.LogOff and DefaultRetries when nil. A nil
// HTTPClient is replaced by a client whose dialer uses a 5s connect timeout
// and 30s keep-alive, with proxies taken from the environment and a 10s TLS
// handshake timeout.
func copyConfig(config *Config) *aws.Config {
	src := config
	if src == nil {
		src = &Config{}
	}

	var out aws.Config
	out.Credentials = credentials.AnonymousCredentials
	out.Endpoint = src.Endpoint

	out.HTTPClient = src.HTTPClient
	if out.HTTPClient == nil {
		// The metadata service is local when it exists at all, so connect
		// attempts get a shorter timeout than usual; off EC2 this makes the
		// client give up quickly instead of waiting out a long dial.
		dialer := &net.Dialer{
			Timeout:   5 * time.Second,
			KeepAlive: 30 * time.Second,
		}
		transport := &http.Transport{
			Proxy:               http.ProxyFromEnvironment,
			Dial:                dialer.Dial,
			TLSHandshakeTimeout: 10 * time.Second,
		}
		out.HTTPClient = &http.Client{Transport: transport}
	}

	out.Logger = src.Logger
	if out.Logger == nil {
		out.Logger = aws.NewDefaultLogger()
	}

	out.LogLevel = src.LogLevel
	if out.LogLevel == nil {
		out.LogLevel = aws.LogLevel(aws.LogOff)
	}

	out.MaxRetries = src.MaxRetries
	if out.MaxRetries == nil {
		out.MaxRetries = aws.Int(DefaultRetries)
	}

	return &out
}