Exemple #1
0
// BuildSigner assembles a signer around an HTTP request described by sb.
//
// The request targets https://<ServiceName>.<Region>.amazonaws.com. For a POST
// the encoded query is sent as the request body, with Content-Type and
// Content-Length headers set; for every other method the encoded query is
// placed in the URL's raw query instead. The signer is given sb.SignTime and
// static credentials built from "AKID", "SECRET" and sb.SessionToken. Debug
// logging is enabled on the signer when the DEBUG environment variable is
// non-empty.
//
// BuildSigner panics if the HTTP request cannot be constructed.
//
// NOTE(review): this body needs "strconv" in the file's import block, which is
// outside this excerpt.
func (sb signerBuilder) BuildSigner() signer {
	endpoint := "https://" + sb.ServiceName + "." + sb.Region + ".amazonaws.com"

	var (
		req *http.Request
		err error
	)
	if sb.Method == "POST" {
		body := []byte(sb.Query.Encode())
		req, err = http.NewRequest(sb.Method, endpoint, bytes.NewReader(body))
		if err != nil {
			panic("BuildSigner: creating request: " + err.Error())
		}
		req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
		// string(len(body)) would produce the rune with that code point, not
		// the decimal length, so format the integer explicitly.
		req.Header.Add("Content-Length", strconv.Itoa(len(body)))
	} else {
		req, err = http.NewRequest(sb.Method, endpoint, nil)
		if err != nil {
			panic("BuildSigner: creating request: " + err.Error())
		}
		req.URL.RawQuery = sb.Query.Encode()
	}

	sig := signer{
		Request: req,
		Time:    sb.SignTime,
		Credentials: credentials.NewStaticCredentials(
			"AKID",
			"SECRET",
			sb.SessionToken),
	}

	if os.Getenv("DEBUG") != "" {
		sig.Debug = aws.LogDebug
		sig.Logger = aws.NewDefaultLogger()
	}

	return sig
}
Exemple #2
0
func TestGet(t *testing.T) {
	assert := assert.New(t)
	svc := awstesting.NewClient(&aws.Config{
		Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
		Region:      aws.String("ap-southeast-2"),
	})
	r := svc.NewRequest(
		&request.Operation{
			Name:       "OpName",
			HTTPMethod: "GET",
			HTTPPath:   "/",
		},
		nil,
		nil,
	)

	r.Build()
	assert.Equal("GET", r.HTTPRequest.Method)
	assert.Equal("", r.HTTPRequest.URL.Query().Get("Signature"))

	Sign(r)
	assert.NoError(r.Error)
	t.Logf("Signature: %s", r.HTTPRequest.URL.Query().Get("Signature"))
	assert.NotEqual("", r.HTTPRequest.URL.Query().Get("Signature"))
}
Exemple #3
0
// TestResignRequestExpiredCreds signs a request, expires the credentials,
// signs again, and expects the Authorization header to differ between the
// two signings.
func TestResignRequestExpiredCreds(t *testing.T) {
	staticCreds := credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")
	client := awstesting.NewClient(&aws.Config{Credentials: staticCreds})

	op := &request.Operation{
		Name:       "BatchGetItem",
		HTTPMethod: "POST",
		HTTPPath:   "/",
	}
	req := client.NewRequest(op, nil, nil)

	// First signing: remember the Authorization header it produced.
	Sign(req)
	firstAuth := req.HTTPRequest.Header.Get("Authorization")

	// Second signing happens after the credentials have been expired.
	staticCreds.Expire()
	Sign(req)
	secondAuth := req.HTTPRequest.Header.Get("Authorization")

	assert.NotEqual(t, firstAuth, secondAuth)
}
Exemple #4
0
// buildSigner returns a signer wrapping a POST request to
// https://<serviceName>.<region>.amazonaws.com whose body is the given string.
//
// The request's URL.Opaque is overridden with a path containing punctuation
// characters, and several headers are added (X-Amz-Target, Content-Type,
// Content-Length and an X-Amz-Meta-* header whose value contains special
// characters). signTime and expireTime are passed through to the signer's Time
// and ExpireTime fields, and the signer uses static credentials
// "AKID"/"SECRET"/"SESSION".
//
// buildSigner panics if the HTTP request cannot be constructed.
//
// NOTE(review): this body needs "strconv" in the file's import block, which is
// outside this excerpt.
func buildSigner(serviceName string, region string, signTime time.Time, expireTime time.Duration, body string) signer {
	endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
	reader := strings.NewReader(body)
	req, err := http.NewRequest("POST", endpoint, reader)
	if err != nil {
		panic("buildSigner: creating request: " + err.Error())
	}
	req.URL.Opaque = "//example.org/bucket/key-._~,!@#$%^&*()"
	req.Header.Add("X-Amz-Target", "prefix.Operation")
	req.Header.Add("Content-Type", "application/x-amz-json-1.0")
	// string(len(body)) would produce the rune with that code point, not the
	// decimal length, so format the integer explicitly.
	req.Header.Add("Content-Length", strconv.Itoa(len(body)))
	req.Header.Add("X-Amz-Meta-Other-Header", "some-value=!@#$%^&* (+)")

	return signer{
		Request:     req,
		Time:        signTime,
		ExpireTime:  expireTime,
		Query:       req.URL.Query(),
		Body:        reader,
		ServiceName: serviceName,
		Region:      region,
		Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
	}
}
Exemple #5
0
func TestIgnoreResignRequestWithValidCreds(t *testing.T) {
	svc := awstesting.NewClient(&aws.Config{
		Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
		Region:      aws.String("us-west-2"),
	})
	r := svc.NewRequest(
		&request.Operation{
			Name:       "BatchGetItem",
			HTTPMethod: "POST",
			HTTPPath:   "/",
		},
		nil,
		nil,
	)

	Sign(r)
	sig := r.HTTPRequest.Header.Get("Authorization")

	Sign(r)
	assert.Equal(t, sig, r.HTTPRequest.Header.Get("Authorization"))
}
// TestRequestRecoverExpiredCreds tests that the request is retried after the
// credentials are expired.
//
// The Send handlers are replaced with a mock that serves two canned responses
// in order: a 400 carrying an ExpiredTokenException body, then a 200 with
// valid data. The test expects exactly one retry, the second response's data
// to be unmarshalled into out, and the client's credentials to be non-expired
// once the request has completed.
func TestRequestRecoverExpiredCreds(t *testing.T) {
	// reqNum indexes the next canned response the mock Send handler hands out.
	reqNum := 0
	reqs := []http.Response{
		{StatusCode: 400, Body: body(`{"__type":"ExpiredTokenException","message":"expired token"}`)},
		{StatusCode: 200, Body: body(`{"data":"valid"}`)},
	}

	s := awstesting.NewClient(&aws.Config{MaxRetries: aws.Int(10), Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "")})
	s.Handlers.Validate.Clear()
	s.Handlers.Unmarshal.PushBack(unmarshal)
	s.Handlers.UnmarshalError.PushBack(unmarshalError)

	// NOTE(review): credExpiredBeforeRetry is never assigned in this function,
	// so the assert.False on it below cannot fail — confirm whether a handler
	// that records the pre-retry state was meant to be registered.
	credExpiredBeforeRetry := false
	credExpiredAfterRetry := false

	// Record whether the credentials report as expired when the AfterRetry
	// handlers run.
	s.Handlers.AfterRetry.PushBack(func(r *request.Request) {
		credExpiredAfterRetry = r.Config.Credentials.IsExpired()
	})

	// Replace the Sign handlers with one that only fetches the credentials;
	// the returned value is discarded.
	s.Handlers.Sign.Clear()
	s.Handlers.Sign.PushBack(func(r *request.Request) {
		r.Config.Credentials.Get()
	})
	s.Handlers.Send.Clear() // mock sending
	s.Handlers.Send.PushBack(func(r *request.Request) {
		// Hand out the next canned response instead of performing real I/O.
		r.HTTPResponse = &reqs[reqNum]
		reqNum++
	})
	out := &testData{}
	r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
	err := r.Send()
	assert.Nil(t, err)

	assert.False(t, credExpiredBeforeRetry, "Expect valid creds before retry check")
	assert.True(t, credExpiredAfterRetry, "Expect expired creds after retry check")
	assert.False(t, s.Config.Credentials.IsExpired(), "Expect valid creds after cred expired recovery")

	assert.Equal(t, 1, int(r.RetryCount))
	assert.Equal(t, "valid", out.Data)
}
Exemple #7
0
// Package unit performs initialization and validation for unit tests
package unit

import (
	"github.com/bluet-deps/aws-sdk-go/aws"
	"github.com/bluet-deps/aws-sdk-go/aws/credentials"
	"github.com/bluet-deps/aws-sdk-go/aws/session"
)

// Session is the session shared by unit tests. It is created with fixed static
// credentials ("AKID", "SECRET", "SESSION") and the region "mock-region".
var Session = session.New(&aws.Config{
	Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
	Region:      aws.String("mock-region"),
})
Exemple #8
0
package aws

import (
	"net/http"
	"reflect"
	"testing"

	"github.com/bluet-deps/aws-sdk-go/aws/credentials"
)

// Fixtures used by the Config copy test.
var (
	// testCredentials is the static credential value placed in copyTestConfig.
	testCredentials = credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")

	// copyTestConfig is a Config with these fields explicitly set to non-zero
	// values, so that a copy can be compared against it field by field.
	copyTestConfig = Config{
		Credentials: testCredentials,

		Endpoint:   String("CopyTestEndpoint"),
		Region:     String("COPY_TEST_AWS_REGION"),
		DisableSSL: Bool(true),
		HTTPClient: http.DefaultClient,

		LogLevel: LogLevel(LogDebug),
		Logger:   NewDefaultLogger(),

		MaxRetries:              Int(3),
		DisableParamValidation:  Bool(true),
		DisableComputeChecksums: Bool(true),
		S3ForcePathStyle:        Bool(true),
	}
)

func TestCopy(t *testing.T) {
	want := copyTestConfig
	got := copyTestConfig.Copy()
	if !reflect.DeepEqual(*got, want) {
		t.Errorf("Copy() = %+v", got)