// Retrieve generates a new set of temporary credentials using STS. func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) { // Apply defaults where parameters are not set. if p.Client == nil { p.Client = sts.New(nil) } if p.RoleSessionName == "" { // Try to work out a role name that will hopefully end up unique. p.RoleSessionName = fmt.Sprintf("%d", time.Now().UTC().UnixNano()) } if p.Duration == 0 { // Expire as often as AWS permits. p.Duration = 15 * time.Minute } roleOutput, err := p.Client.AssumeRole(&sts.AssumeRoleInput{ DurationSeconds: aws.Int64(int64(p.Duration / time.Second)), RoleArn: aws.String(p.RoleARN), RoleSessionName: aws.String(p.RoleSessionName), ExternalId: p.ExternalID, }) if err != nil { return credentials.Value{}, err } // We will proactively generate new credentials before they expire. p.SetExpiration(*roleOutput.Credentials.Expiration, p.ExpiryWindow) return credentials.Value{ AccessKeyID: *roleOutput.Credentials.AccessKeyId, SecretAccessKey: *roleOutput.Credentials.SecretAccessKey, SessionToken: *roleOutput.Credentials.SessionToken, }, nil }
func TestPutBuffers(t *testing.T) { testCases := [][]message{ {{aws.String("twinkle"), []byte("twinkle")}}, {{aws.String("hey"), []byte("there")}, {aws.String("big"), []byte("fella")}}, } for _, testCase := range testCases { producer := producerWithStubClient(len(testCase) + 1) client := producer.client.(*StubClient) for _, input := range testCase { err := producer.Put(input.PartitionKey, input.Value) if err != nil { t.Fatalf("expected no Put errors. got '%s'", err) } } if client.puts != 0 { t.Errorf("expected 0 sent messages, got %d", client.puts) } if producer.current != len(testCase) { t.Errorf("expected %d buffered messages, found %d", len(testCase), len(producer.messages)) } } }
func TestPutRetriesFailedRecords(t *testing.T) { testCases := []struct { name string messages []message responses []clientResponse }{ { "one message retrying once", []message{{aws.String("twinkle"), []byte("twinkle")}}, []clientResponse{ {outputWithErrors("ProvisionedThroughputExceededException"), nil}, }, }, { "one message retrying twice", []message{{aws.String("twinkle"), []byte("twinkle")}}, []clientResponse{ {outputWithErrors("ProvisionedThroughputExceededException"), nil}, {outputWithErrors("ProvisionedThroughputExceededException"), nil}, }, }, { "two messages with one retrying once", []message{{aws.String("hey"), []byte("there")}, {aws.String("big"), []byte("fella")}}, []clientResponse{ {outputWithErrors("", "ProvisionedThroughputExceededException"), nil}, }, }, } for _, testCase := range testCases { actualRetries := 0 producer := producerRespondingWith(MaxSendSize, testCase.responses...) producer.Throttle = func() Throttle { actualRetries++ return &noOpThrottle{} } client := producer.client.(*StubClient) for _, m := range testCase.messages { if err := producer.Put(m.PartitionKey, m.Value); err != nil { t.Fatalf("unexpected producer error! %s", err) } } producer.Flush() expectedRetries := len(testCase.responses) if actualRetries != expectedRetries { t.Errorf("expected %d retries, got %d", expectedRetries, actualRetries) } assertSentMessages(t, testCase.name, testCase.messages, client.sent) } }
func shardsToAws(shards ...shard) []*kinesis.Shard { var awsShards []*kinesis.Shard for _, s := range shards { awsShard := &kinesis.Shard{ShardId: aws.String(s.id)} if s.parentOne != "" { awsShard.ParentShardId = aws.String(s.parentOne) } if s.parentTwo != "" { awsShard.AdjacentParentShardId = aws.String(s.parentTwo) } awsShards = append(awsShards, awsShard) } return awsShards }
// A response where every record in the given input is a success. func successfulResponse(input *kinesis.PutRecordsInput) (*kinesis.PutRecordsOutput, error) { records := make([]*kinesis.PutRecordsResultEntry, len(input.Records)) for i := range input.Records { records[i] = &kinesis.PutRecordsResultEntry{ SequenceNumber: aws.String("sequence_number"), ShardId: aws.String("an_shard"), } } output := &kinesis.PutRecordsOutput{ FailedRecordCount: aws.Int64(0), Records: records, } return output, nil }
func (p *Producer) send() error { if len(p.messages) == 0 { return nil } defer p.reset() stream, messages := aws.String(p.StreamName), p.messages[0:p.current] for { res, err := p.client.PutRecords(putRecordsInput(stream, messages)) if err != nil { return err } if *res.FailedRecordCount == 0 { if p.Debug { log.Printf("Put %d message(s).", len(res.Records)) } return nil } messages = failedMessages(messages, res.Records) if p.Debug { log.Printf("Put failed for %d message(s). Backing off and trying again.", *res.FailedRecordCount) } p.Throttle().Await() } return nil }
// consumerWith builds a Consumer for tests. It reads the default stream
// through a StubClient that serves the given shard descriptions and
// per-shard records. The waiters it creates are stub waiters.
func consumerWith(descriptions [][]shard, data map[string][]string, processor Processor) *Consumer {
	stub := &StubClient{describe: descriptions, records: data}
	newWaiter := func() waiter { return &stubWaiter{} }

	return &Consumer{
		stream:     aws.String(defaultStream),
		client:     stub,
		processor:  processor,
		complete:   make(chan string),
		waiterFunc: newWaiter,
	}
}
func TestInvalidPut(t *testing.T) { invalid := []struct { name string key *string value []byte expected error }{ {"key too long", aws.String(randomString(257)), []byte{0x6c, 0x6f, 0x6c}, PartitionKeyTooLong}, {"invalid unicode", aws.String(string([]byte{0xc1, 0xbf})), []byte{0x6c, 0x6f, 0x6c}, InvalidUnicode}, {"empty value", aws.String(randomString(123)), []byte{}, EmptyValue}, } producer := producerWithStubClient(123) for _, tc := range invalid { if err := producer.Put(tc.key, tc.value); !errContains(err, tc.expected) { t.Errorf("expected %s to error with '%s'. got '%+v'", tc.name, tc.expected, err) } } }
// Start a consumer at the given Stream's LATEST and process each shard with // processor. // // Each shard will be processed in an individual goroutine. func Tail(stream string, debug bool, processor Processor) error { c := &Consumer{ stream: aws.String(stream), client: kinesis.New(nil), processor: processor, debug: debug, complete: make(chan string), waiterFunc: func() waiter { return &realWaiter{} }, } return c.tail() }
func TestPutSends(t *testing.T) { testCases := [][]message{ {{aws.String("twinkle"), []byte("twinkle")}}, {{aws.String("hey"), []byte("there")}, {aws.String("big"), []byte("fella")}}, } for _, testCase := range testCases { producer := producerWithStubClient(len(testCase)) client := producer.client.(*StubClient) for _, input := range testCase { err := producer.Put(input.PartitionKey, input.Value) if err != nil { t.Fatalf("expected no Put errors. got '%s'", err) } } if client.puts != 1 { t.Errorf("expected a single send to Kinesis") return } assertSentMessages(t, fmt.Sprintf("%d successful records", len(testCase)), testCase, client.sent) } }
func (c *Consumer) startShardConsumer(shard string, iterType *string, processor Processor) { s := &shardConsumer{ client: c.client, stream: c.stream, shard: aws.String(shard), debug: c.debug, processor: processor, waiter: c.waiterFunc(), complete: c.complete, } go func() { s.init(iterType) s.consume() }() }
// Return an output response with the given error codes. Blank strings imply // that a message was not an error. func outputWithErrors(codes ...string) *kinesis.PutRecordsOutput { resultEntries := make([]*kinesis.PutRecordsResultEntry, len(codes)) errorCount := 0 for i, code := range codes { resultEntries[i] = &kinesis.PutRecordsResultEntry{ ErrorCode: aws.String(code), } if code != "" { errorCount++ } } return &kinesis.PutRecordsOutput{ FailedRecordCount: aws.Int64(int64(errorCount)), Records: resultEntries, } }
// Send the given string to Kinesis. The first 256 bytes of the string will be // used as the partition key. message must not be a valid Unicode string, and // must be non-empty. func (p *Producer) PutString(message string) error { return p.Put(aws.String(message[:intMin(len(message), 256)]), []byte(message)) }
func (c *Consumer) tail() error { shards, err := c.listShards() if err != nil { return err } go c.monitor() for _, id := range withNoChildren(shards) { c.startShardConsumer(id, LATEST, c.processor) } return nil } var LATEST = aws.String(kinesis.ShardIteratorTypeLatest) var TRIM_HORIZON = aws.String(kinesis.ShardIteratorTypeTrimHorizon) func (c *Consumer) startShardConsumer(shard string, iterType *string, processor Processor) { s := &shardConsumer{ client: c.client, stream: c.stream, shard: aws.String(shard), debug: c.debug, processor: processor, waiter: c.waiterFunc(), complete: c.complete, } go func() {