// Retrieve generates a new set of temporary credentials using STS. func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) { // Apply defaults where parameters are not set. if p.RoleSessionName == "" { // Try to work out a role name that will hopefully end up unique. p.RoleSessionName = fmt.Sprintf("%d", time.Now().UTC().UnixNano()) } if p.Duration == 0 { // Expire as often as AWS permits. p.Duration = DefaultDuration } roleOutput, err := p.Client.AssumeRole(&sts.AssumeRoleInput{ DurationSeconds: aws.Int64(int64(p.Duration / time.Second)), RoleArn: aws.String(p.RoleARN), RoleSessionName: aws.String(p.RoleSessionName), ExternalId: p.ExternalID, }) if err != nil { return credentials.Value{}, err } // We will proactively generate new credentials before they expire. p.SetExpiration(*roleOutput.Credentials.Expiration, p.ExpiryWindow) return credentials.Value{ AccessKeyID: *roleOutput.Credentials.AccessKeyId, SecretAccessKey: *roleOutput.Credentials.SecretAccessKey, SessionToken: *roleOutput.Credentials.SessionToken, }, nil }
func TestPresignHandler(t *testing.T) { svc := s3.New(unit.Session) req, _ := svc.PutObjectRequest(&s3.PutObjectInput{ Bucket: aws.String("bucket"), Key: aws.String("key"), ContentDisposition: aws.String("a+b c$d"), ACL: aws.String("public-read"), }) req.Time = time.Unix(0, 0) urlstr, err := req.Presign(5 * time.Minute) assert.NoError(t, err) expectedDate := "19700101T000000Z" expectedHeaders := "host;x-amz-acl" expectedSig := "7edcb4e3a1bf12f4989018d75acbe3a7f03df24bd6f3112602d59fc551f0e4e2" expectedCred := "AKID/19700101/mock-region/s3/aws4_request" u, _ := url.Parse(urlstr) urlQ := u.Query() assert.Equal(t, expectedSig, urlQ.Get("X-Amz-Signature")) assert.Equal(t, expectedCred, urlQ.Get("X-Amz-Credential")) assert.Equal(t, expectedHeaders, urlQ.Get("X-Amz-SignedHeaders")) assert.Equal(t, expectedDate, urlQ.Get("X-Amz-Date")) assert.Equal(t, "300", urlQ.Get("X-Amz-Expires")) assert.NotContains(t, urlstr, "+") // + encoded as %20 }
func TestUnsignedRequest_AssumeRoleWithWebIdentity(t *testing.T) { req, _ := svc.AssumeRoleWithWebIdentityRequest(&sts.AssumeRoleWithWebIdentityInput{ RoleArn: aws.String("ARN01234567890123456789"), RoleSessionName: aws.String("SESSION"), WebIdentityToken: aws.String("TOKEN"), }) err := req.Sign() assert.NoError(t, err) assert.Equal(t, "", req.HTTPRequest.Header.Get("Authorization")) }
func TestUnsignedRequest_AssumeRoleWithSAML(t *testing.T) { req, _ := svc.AssumeRoleWithSAMLRequest(&sts.AssumeRoleWithSAMLInput{ PrincipalArn: aws.String("ARN01234567890123456789"), RoleArn: aws.String("ARN01234567890123456789"), SAMLAssertion: aws.String("ASSERT"), }) err := req.Sign() assert.NoError(t, err) assert.Equal(t, "", req.HTTPRequest.Header.Get("Authorization")) }
func (s *stubSTS) AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) { expiry := time.Now().Add(60 * time.Minute) return &sts.AssumeRoleOutput{ Credentials: &sts.Credentials{ // Just reflect the role arn to the provider. AccessKeyId: input.RoleArn, SecretAccessKey: aws.String("assumedSecretAccessKey"), SessionToken: aws.String("assumedSessionToken"), Expiration: &expiry, }, }, nil }
// Use DynamoDB methods for simplicity func TestPaginationEachPage(t *testing.T) { db := dynamodb.New(unit.Session) tokens, pages, numPages, gotToEnd := []string{}, []string{}, 0, false reqNum := 0 resps := []*dynamodb.ListTablesOutput{ {TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")}, {TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")}, {TableNames: []*string{aws.String("Table5")}}, } db.Handlers.Send.Clear() // mock sending db.Handlers.Unmarshal.Clear() db.Handlers.UnmarshalMeta.Clear() db.Handlers.ValidateResponse.Clear() db.Handlers.Build.PushBack(func(r *request.Request) { in := r.Params.(*dynamodb.ListTablesInput) if in == nil { tokens = append(tokens, "") } else if in.ExclusiveStartTableName != nil { tokens = append(tokens, *in.ExclusiveStartTableName) } }) db.Handlers.Unmarshal.PushBack(func(r *request.Request) { r.Data = resps[reqNum] reqNum++ }) params := &dynamodb.ListTablesInput{Limit: aws.Int64(2)} req, _ := db.ListTablesRequest(params) err := req.EachPage(func(p interface{}, last bool) bool { numPages++ for _, t := range p.(*dynamodb.ListTablesOutput).TableNames { pages = append(pages, *t) } if last { if gotToEnd { assert.Fail(t, "last=true happened twice") } gotToEnd = true } return true }) assert.Equal(t, []string{"Table2", "Table4"}, tokens) assert.Equal(t, []string{"Table1", "Table2", "Table3", "Table4", "Table5"}, pages) assert.Equal(t, 3, numPages) assert.True(t, gotToEnd) assert.Nil(t, err) }
func TestNoErrors(t *testing.T) { input := &StructShape{ RequiredList: []*ConditionalStructShape{}, RequiredMap: map[string]*ConditionalStructShape{ "key1": {Name: aws.String("Name")}, "key2": {Name: aws.String("Name")}, }, RequiredBool: aws.Bool(true), OptionalStruct: &ConditionalStructShape{Name: aws.String("Name")}, } req := testSvc.NewRequest(&request.Operation{}, input, nil) corehandlers.ValidateParametersHandler.Fn(req) require.NoError(t, req.Error) }
func TestEC2RoleProviderExpiryWindowIsExpired(t *testing.T) { server := initTestServer("2014-12-16T01:51:37Z", false) defer server.Close() p := &ec2rolecreds.EC2RoleProvider{ Client: ec2metadata.New(session.New(), &aws.Config{Endpoint: aws.String(server.URL + "/latest")}), ExpiryWindow: time.Hour * 1, } p.CurrentTime = func() time.Time { return time.Date(2014, 12, 15, 0, 51, 37, 0, time.UTC) } assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.") _, err := p.Retrieve() assert.Nil(t, err, "Expect no error, %v", err) assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.") p.CurrentTime = func() time.Time { return time.Date(2014, 12, 16, 0, 55, 37, 0, time.UTC) } assert.True(t, p.IsExpired(), "Expect creds to be expired.") }
func TestNewDefaultSession(t *testing.T) { s := session.New(&aws.Config{Region: aws.String("region")}) assert.Equal(t, "region", *s.Config.Region) assert.Equal(t, http.DefaultClient, s.Config.HTTPClient) assert.NotNil(t, s.Config.Logger) assert.Equal(t, aws.LogOff, *s.Config.LogLevel) }
// Use S3 for simplicity func TestPaginationTruncation(t *testing.T) { client := s3.New(unit.Session) reqNum := 0 resps := []*s3.ListObjectsOutput{ {IsTruncated: aws.Bool(true), Contents: []*s3.Object{{Key: aws.String("Key1")}}}, {IsTruncated: aws.Bool(true), Contents: []*s3.Object{{Key: aws.String("Key2")}}}, {IsTruncated: aws.Bool(false), Contents: []*s3.Object{{Key: aws.String("Key3")}}}, {IsTruncated: aws.Bool(true), Contents: []*s3.Object{{Key: aws.String("Key4")}}}, } client.Handlers.Send.Clear() // mock sending client.Handlers.Unmarshal.Clear() client.Handlers.UnmarshalMeta.Clear() client.Handlers.ValidateResponse.Clear() client.Handlers.Unmarshal.PushBack(func(r *request.Request) { r.Data = resps[reqNum] reqNum++ }) params := &s3.ListObjectsInput{Bucket: aws.String("bucket")} results := []string{} err := client.ListObjectsPages(params, func(p *s3.ListObjectsOutput, last bool) bool { results = append(results, *p.Contents[0].Key) return true }) assert.Equal(t, []string{"Key1", "Key2", "Key3"}, results) assert.Nil(t, err) // Try again without truncation token at all reqNum = 0 resps[1].IsTruncated = nil resps[2].IsTruncated = aws.Bool(true) results = []string{} err = client.ListObjectsPages(params, func(p *s3.ListObjectsOutput, last bool) bool { results = append(results, *p.Contents[0].Key) return true }) assert.Equal(t, []string{"Key1", "Key2"}, results) assert.Nil(t, err) }
func TestDeepEqual(t *testing.T) { cases := []struct { a, b interface{} equal bool }{ {"a", "a", true}, {"a", "b", false}, {"a", aws.String(""), false}, {"a", nil, false}, {"a", aws.String("a"), true}, {(*bool)(nil), (*bool)(nil), true}, {(*bool)(nil), (*string)(nil), false}, {nil, nil, true}, } for i, c := range cases { assert.Equal(t, c.equal, awsutil.DeepEqual(c.a, c.b), "%d, a:%v b:%v, %t", i, c.a, c.b, c.equal) } }
func ExampleSTS_GetSessionToken() { svc := sts.New(session.New()) params := &sts.GetSessionTokenInput{ DurationSeconds: aws.Int64(1), SerialNumber: aws.String("serialNumberType"), TokenCode: aws.String("tokenCodeType"), } resp, err := svc.GetSessionToken(params) if err != nil { // Print the error, cast err to awserr.Error to get the Code and // Message from an error. fmt.Println(err.Error()) return } // Pretty-print the response data. fmt.Println(resp) }
func ExampleSTS_GetFederationToken() { svc := sts.New(session.New()) params := &sts.GetFederationTokenInput{ Name: aws.String("userNameType"), // Required DurationSeconds: aws.Int64(1), Policy: aws.String("sessionPolicyDocumentType"), } resp, err := svc.GetFederationToken(params) if err != nil { // Print the error, cast err to awserr.Error to get the Code and // Message from an error. fmt.Println(err.Error()) return } // Pretty-print the response data. fmt.Println(resp) }
func TestRequestUserAgent(t *testing.T) { s := awstesting.NewClient(&aws.Config{Region: aws.String("us-east-1")}) // s.Handlers.Validate.Clear() req := s.NewRequest(&request.Operation{Name: "Operation"}, nil, &testData{}) req.HTTPRequest.Header.Set("User-Agent", "foo/bar") assert.NoError(t, req.Build()) expectUA := fmt.Sprintf("foo/bar %s/%s (%s; %s; %s)", aws.SDKName, aws.SDKVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH) assert.Equal(t, expectUA, req.HTTPRequest.Header.Get("User-Agent")) }
// Use DynamoDB methods for simplicity func TestPaginationEarlyExit(t *testing.T) { db := dynamodb.New(unit.Session) numPages, gotToEnd := 0, false reqNum := 0 resps := []*dynamodb.ListTablesOutput{ {TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")}, {TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")}, {TableNames: []*string{aws.String("Table5")}}, } db.Handlers.Send.Clear() // mock sending db.Handlers.Unmarshal.Clear() db.Handlers.UnmarshalMeta.Clear() db.Handlers.ValidateResponse.Clear() db.Handlers.Unmarshal.PushBack(func(r *request.Request) { r.Data = resps[reqNum] reqNum++ }) params := &dynamodb.ListTablesInput{Limit: aws.Int64(2)} err := db.ListTablesPages(params, func(p *dynamodb.ListTablesOutput, last bool) bool { numPages++ if numPages == 2 { return false } if last { if gotToEnd { assert.Fail(t, "last=true happened twice") } gotToEnd = true } return true }) assert.Equal(t, 2, numPages) assert.False(t, gotToEnd) assert.Nil(t, err) }
func ExampleSTS_AssumeRoleWithSAML() { svc := sts.New(session.New()) params := &sts.AssumeRoleWithSAMLInput{ PrincipalArn: aws.String("arnType"), // Required RoleArn: aws.String("arnType"), // Required SAMLAssertion: aws.String("SAMLAssertionType"), // Required DurationSeconds: aws.Int64(1), Policy: aws.String("sessionPolicyDocumentType"), } resp, err := svc.AssumeRoleWithSAML(params) if err != nil { // Print the error, cast err to awserr.Error to get the Code and // Message from an error. fmt.Println(err.Error()) return } // Pretty-print the response data. fmt.Println(resp) }
func ExampleSTS_AssumeRoleWithWebIdentity() { svc := sts.New(session.New()) params := &sts.AssumeRoleWithWebIdentityInput{ RoleArn: aws.String("arnType"), // Required RoleSessionName: aws.String("roleSessionNameType"), // Required WebIdentityToken: aws.String("clientTokenType"), // Required DurationSeconds: aws.Int64(1), Policy: aws.String("sessionPolicyDocumentType"), ProviderId: aws.String("urlType"), } resp, err := svc.AssumeRoleWithWebIdentity(params) if err != nil { // Print the error, cast err to awserr.Error to get the Code and // Message from an error. fmt.Println(err.Error()) return } // Pretty-print the response data. fmt.Println(resp) }
func TestEC2RoleProvider(t *testing.T) { server := initTestServer("2014-12-16T01:51:37Z", false) defer server.Close() p := &ec2rolecreds.EC2RoleProvider{ Client: ec2metadata.New(session.New(), &aws.Config{Endpoint: aws.String(server.URL + "/latest")}), } creds, err := p.Retrieve() assert.Nil(t, err, "Expect no error, %v", err) assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match") assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match") assert.Equal(t, "token", creds.SessionToken, "Expect session token to match") }
func TestNestedMissingRequiredParameters(t *testing.T) { input := &StructShape{ RequiredList: []*ConditionalStructShape{{}}, RequiredMap: map[string]*ConditionalStructShape{ "key1": {Name: aws.String("Name")}, "key2": {}, }, RequiredBool: aws.Bool(true), OptionalStruct: &ConditionalStructShape{}, } req := testSvc.NewRequest(&request.Operation{}, input, nil) corehandlers.ValidateParametersHandler.Fn(req) require.Error(t, req.Error) assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code()) assert.Equal(t, "3 validation errors:\n- missing required parameter: RequiredList[0].Name\n- missing required parameter: RequiredMap[\"key2\"].Name\n- missing required parameter: OptionalStruct.Name", req.Error.(awserr.Error).Message()) }
func ExampleSTS_DecodeAuthorizationMessage() { svc := sts.New(session.New()) params := &sts.DecodeAuthorizationMessageInput{ EncodedMessage: aws.String("encodedMessageType"), // Required } resp, err := svc.DecodeAuthorizationMessage(params) if err != nil { // Print the error, cast err to awserr.Error to get the Code and // Message from an error. fmt.Println(err.Error()) return } // Pretty-print the response data. fmt.Println(resp) }
func BenchmarkEC3RoleProvider(b *testing.B) { server := initTestServer("2014-12-16T01:51:37Z", false) defer server.Close() p := &ec2rolecreds.EC2RoleProvider{ Client: ec2metadata.New(session.New(), &aws.Config{Endpoint: aws.String(server.URL + "/latest")}), } _, err := p.Retrieve() if err != nil { b.Fatal(err) } b.ResetTimer() for i := 0; i < b.N; i++ { if _, err := p.Retrieve(); err != nil { b.Fatal(err) } } }
func TestEC2RoleProviderFailAssume(t *testing.T) { server := initTestServer("2014-12-16T01:51:37Z", true) defer server.Close() p := &ec2rolecreds.EC2RoleProvider{ Client: ec2metadata.New(session.New(), &aws.Config{Endpoint: aws.String(server.URL + "/latest")}), } creds, err := p.Retrieve() assert.Error(t, err, "Expect error") e := err.(awserr.Error) assert.Equal(t, "ErrorCode", e.Code()) assert.Equal(t, "ErrorMsg", e.Message()) assert.Nil(t, e.OrigErr()) assert.Equal(t, "", creds.AccessKeyID, "Expect access key ID to match") assert.Equal(t, "", creds.SecretAccessKey, "Expect secret access key to match") assert.Equal(t, "", creds.SessionToken, "Expect session token to match") }
func TestIgnoreResignRequestWithValidCreds(t *testing.T) { svc := awstesting.NewClient(&aws.Config{ Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"), Region: aws.String("us-west-2"), }) r := svc.NewRequest( &request.Operation{ Name: "BatchGetItem", HTTPMethod: "POST", HTTPPath: "/", }, nil, nil, ) Sign(r) sig := r.HTTPRequest.Header.Get("Authorization") Sign(r) assert.Equal(t, sig, r.HTTPRequest.Header.Get("Authorization")) }
func TestSkipPagination(t *testing.T) { client := s3.New(unit.Session) client.Handlers.Send.Clear() // mock sending client.Handlers.Unmarshal.Clear() client.Handlers.UnmarshalMeta.Clear() client.Handlers.ValidateResponse.Clear() client.Handlers.Unmarshal.PushBack(func(r *request.Request) { r.Data = &s3.HeadBucketOutput{} }) req, _ := client.HeadBucketRequest(&s3.HeadBucketInput{Bucket: aws.String("bucket")}) numPages, gotToEnd := 0, false req.EachPage(func(p interface{}, last bool) bool { numPages++ if last { gotToEnd = true } return true }) assert.Equal(t, 1, numPages) assert.True(t, gotToEnd) }
func ExampleSTS_AssumeRole() { svc := sts.New(session.New()) params := &sts.AssumeRoleInput{ RoleArn: aws.String("arnType"), // Required RoleSessionName: aws.String("roleSessionNameType"), // Required DurationSeconds: aws.Int64(1), ExternalId: aws.String("externalIdType"), Policy: aws.String("sessionPolicyDocumentType"), SerialNumber: aws.String("serialNumberType"), TokenCode: aws.String("tokenCodeType"), } resp, err := svc.AssumeRole(params) if err != nil { // Print the error, cast err to awserr.Error to get the Code and // Message from an error. fmt.Println(err.Error()) return } // Pretty-print the response data. fmt.Println(resp) }
}{ { err: awserr.New("InvalidParameter", "1 validation errors:\n- field too short, minimum length 5: StringField", nil), in: testInput{StringField: "abcd"}, }, { err: awserr.New("InvalidParameter", "2 validation errors:\n- field too short, minimum length 5: StringField\n- field too short, minimum length 3: ListField", nil), in: testInput{StringField: "abcd", ListField: []string{"a", "b"}}, }, { err: awserr.New("InvalidParameter", "3 validation errors:\n- field too short, minimum length 5: StringField\n- field too short, minimum length 3: ListField\n- field too short, minimum length 4: MapField", nil), in: testInput{StringField: "abcd", ListField: []string{"a", "b"}, MapField: map[string]string{"a": "a", "b": "b"}}, }, { err: awserr.New("InvalidParameter", "1 validation errors:\n- field too short, minimum length 2: PtrStrField", nil), in: testInput{StringField: "abcde", PtrStrField: aws.String("v")}, }, { err: nil, in: testInput{StringField: "abcde", PtrStrField: aws.String("value"), ListField: []string{"a", "b", "c"}, MapField: map[string]string{"a": "a", "b": "b", "c": "c", "d": "d"}}, }, } func TestValidateFieldMinParameter(t *testing.T) { for i, c := range testsFieldMin { req := testSvc.NewRequest(&request.Operation{}, &c.in, nil) corehandlers.ValidateParametersHandler.Fn(req) require.Equal(t, c.err, req.Error, "%d case failed", i) }
package sts_test import ( "testing" "github.com/stretchr/testify/assert" "github.com/aws/aws-sdk-go/awstesting/unit" "github.com/jtblin/aws-mock-metadata/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws" "github.com/jtblin/aws-mock-metadata/Godeps/_workspace/src/github.com/aws/aws-sdk-go/service/sts" ) var svc = sts.New(unit.Session, &aws.Config{ Region: aws.String("mock-region"), }) func TestUnsignedRequest_AssumeRoleWithSAML(t *testing.T) { req, _ := svc.AssumeRoleWithSAMLRequest(&sts.AssumeRoleWithSAMLInput{ PrincipalArn: aws.String("ARN01234567890123456789"), RoleArn: aws.String("ARN01234567890123456789"), SAMLAssertion: aws.String("ASSERT"), }) err := req.Sign() assert.NoError(t, err) assert.Equal(t, "", req.HTTPRequest.Header.Get("Authorization")) } func TestUnsignedRequest_AssumeRoleWithWebIdentity(t *testing.T) { req, _ := svc.AssumeRoleWithWebIdentityRequest(&sts.AssumeRoleWithWebIdentityInput{ RoleArn: aws.String("ARN01234567890123456789"),
} results := []string{} err := client.ListResourceRecordSetsPages(params, func(p *route53.ListResourceRecordSetsOutput, last bool) bool { results = append(results, *p.ResourceRecordSets[0].Name) return true }) assert.NoError(t, err) assert.Equal(t, []string{"", "second", ""}, idents) assert.Equal(t, []string{"first.example.com.", "second.example.com.", "third.example.com."}, results) } // Benchmarks var benchResps = []*dynamodb.ListTablesOutput{ {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, {TableNames: []*string{aws.String("TABLE")}}, }
func TestPaginationNilToken(t *testing.T) { client := route53.New(unit.Session) reqNum := 0 resps := []*route53.ListResourceRecordSetsOutput{ { ResourceRecordSets: []*route53.ResourceRecordSet{ {Name: aws.String("first.example.com.")}, }, IsTruncated: aws.Bool(true), NextRecordName: aws.String("second.example.com."), NextRecordType: aws.String("MX"), NextRecordIdentifier: aws.String("second"), MaxItems: aws.String("1"), }, { ResourceRecordSets: []*route53.ResourceRecordSet{ {Name: aws.String("second.example.com.")}, }, IsTruncated: aws.Bool(true), NextRecordName: aws.String("third.example.com."), NextRecordType: aws.String("MX"), MaxItems: aws.String("1"), }, { ResourceRecordSets: []*route53.ResourceRecordSet{ {Name: aws.String("third.example.com.")}, }, IsTruncated: aws.Bool(false), MaxItems: aws.String("1"), }, } client.Handlers.Send.Clear() // mock sending client.Handlers.Unmarshal.Clear() client.Handlers.UnmarshalMeta.Clear() client.Handlers.ValidateResponse.Clear() idents := []string{} client.Handlers.Build.PushBack(func(r *request.Request) { p := r.Params.(*route53.ListResourceRecordSetsInput) idents = append(idents, aws.StringValue(p.StartRecordIdentifier)) }) client.Handlers.Unmarshal.PushBack(func(r *request.Request) { r.Data = resps[reqNum] reqNum++ }) params := &route53.ListResourceRecordSetsInput{ HostedZoneId: aws.String("id-zone"), } results := []string{} err := client.ListResourceRecordSetsPages(params, func(p *route53.ListResourceRecordSetsOutput, last bool) bool { results = append(results, *p.ResourceRecordSets[0].Name) return true }) assert.NoError(t, err) assert.Equal(t, []string{"", "second", ""}, idents) assert.Equal(t, []string{"first.example.com.", "second.example.com.", "third.example.com."}, results) }
// Use DynamoDB methods for simplicity func TestPaginationQueryPage(t *testing.T) { db := dynamodb.New(unit.Session) tokens, pages, numPages, gotToEnd := []map[string]*dynamodb.AttributeValue{}, []map[string]*dynamodb.AttributeValue{}, 0, false reqNum := 0 resps := []*dynamodb.QueryOutput{ { LastEvaluatedKey: map[string]*dynamodb.AttributeValue{"key": {S: aws.String("key1")}}, Count: aws.Int64(1), Items: []map[string]*dynamodb.AttributeValue{ map[string]*dynamodb.AttributeValue{ "key": {S: aws.String("key1")}, }, }, }, { LastEvaluatedKey: map[string]*dynamodb.AttributeValue{"key": {S: aws.String("key2")}}, Count: aws.Int64(1), Items: []map[string]*dynamodb.AttributeValue{ map[string]*dynamodb.AttributeValue{ "key": {S: aws.String("key2")}, }, }, }, { LastEvaluatedKey: map[string]*dynamodb.AttributeValue{}, Count: aws.Int64(1), Items: []map[string]*dynamodb.AttributeValue{ map[string]*dynamodb.AttributeValue{ "key": {S: aws.String("key3")}, }, }, }, } db.Handlers.Send.Clear() // mock sending db.Handlers.Unmarshal.Clear() db.Handlers.UnmarshalMeta.Clear() db.Handlers.ValidateResponse.Clear() db.Handlers.Build.PushBack(func(r *request.Request) { in := r.Params.(*dynamodb.QueryInput) if in == nil { tokens = append(tokens, nil) } else if len(in.ExclusiveStartKey) != 0 { tokens = append(tokens, in.ExclusiveStartKey) } }) db.Handlers.Unmarshal.PushBack(func(r *request.Request) { r.Data = resps[reqNum] reqNum++ }) params := &dynamodb.QueryInput{ Limit: aws.Int64(2), TableName: aws.String("tablename"), } err := db.QueryPages(params, func(p *dynamodb.QueryOutput, last bool) bool { numPages++ for _, item := range p.Items { pages = append(pages, item) } if last { if gotToEnd { assert.Fail(t, "last=true happened twice") } gotToEnd = true } return true }) assert.Nil(t, err) assert.Equal(t, []map[string]*dynamodb.AttributeValue{ map[string]*dynamodb.AttributeValue{"key": {S: aws.String("key1")}}, map[string]*dynamodb.AttributeValue{"key": {S: aws.String("key2")}}, }, tokens) assert.Equal(t, []map[string]*dynamodb.AttributeValue{ map[string]*dynamodb.AttributeValue{"key": {S: aws.String("key1")}}, map[string]*dynamodb.AttributeValue{"key": {S: aws.String("key2")}}, map[string]*dynamodb.AttributeValue{"key": {S: aws.String("key3")}}, }, pages) assert.Equal(t, 3, numPages) assert.True(t, gotToEnd) assert.Nil(t, params.ExclusiveStartKey) }