// reportError reports a fatal error and then panics with err.
//
// desc is a short description of what failed. Callers conventionally end it
// with ": ", which is trimmed here so the output reads "desc: err" instead of
// "desc: : err". If SNS failure notifications are enabled and an SNS topic plus
// AWS access/secret keys are configured, the error is first published to that
// topic. Any failure while notifying is logged and never prevents the panic.
// The error is printed to stderr before panicking.
func reportError(desc string, err error) {
	desc = strings.TrimSuffix(strings.TrimSpace(desc), ":")
	if err == nil {
		// panic(nil) is easy to misread (and is special-cased by recover in
		// older Go versions), so always panic with a real error value.
		err = fmt.Errorf("no underlying error")
	}

	if cfg.Sns.FailureNotifications && len(cfg.Sns.Topic) > 0 && len(cfg.Aws.Accesskey) > 0 && len(cfg.Aws.Secretkey) > 0 {
		newSns, snsErr := sns.New(*aws.NewAuth(cfg.Aws.Accesskey, cfg.Aws.Secretkey, "", time.Now()), aws.Regions[cfg.Aws.Region])
		if snsErr != nil {
			// No usable client: do not try to publish through it, or we could
			// crash before the original error is ever printed.
			log.Printf("SNS error: %#v during report of error: %#v", snsErr, err)
		} else if _, pubErr := newSns.Publish(&sns.PublishOptions{fmt.Sprintf("%s: %#v", desc, err), "", "[redshift-tracking-copy-from-s3] ERROR Notification", cfg.Sns.Topic, ""}); pubErr != nil {
			log.Printf("SNS publish error: %#v during report of error: %#v", pubErr, err)
		}
	}

	fmt.Fprintf(os.Stderr, "%s: %s\n", desc, err)
	panic(err)
}
// main parses command-line flags and the config file, then starts one watcher
// goroutine per entry in cfg.Redshift.Tables, pairing table i with
// cfg.S3.Buckets[i] and cfg.S3.Prefixes[i]. Each watcher repeatedly COPYs into
// its table the S3 files under its bucket/prefix that STL_FILE_SCAN does not
// list as loaded. On os.Interrupt every watcher stops at its next check, and
// main returns only after all watchers have finished.
func main() {
	flag.Parse() // Read argv

	if shouldOutputVersion {
		fmt.Printf("redshift-tracking-copy-from-s3 %s\n", VERSION)
		os.Exit(0)
	}

	// Read config file
	parseConfigfile()

	tables := cfg.Redshift.Tables
	// Every table needs a matching bucket and prefix; check up front instead of
	// dying later with an index-out-of-range panic.
	if len(cfg.S3.Buckets) < len(tables) || len(cfg.S3.Prefixes) < len(tables) {
		reportError("Invalid configuration: ", fmt.Errorf("%d tables configured but only %d S3 buckets and %d S3 prefixes", len(tables), len(cfg.S3.Buckets), len(cfg.S3.Prefixes)))
	}

	// A single interrupt handler closes quit; closing a channel is observed by
	// every watcher without any shared mutable flag.
	quitSignal := make(chan os.Signal, 1)
	signal.Notify(quitSignal, os.Interrupt)
	quit := make(chan struct{})
	go func() {
		<-quitSignal
		close(quit)
	}()

	// ----------------------------- Startup goroutine for each Bucket/Prefix/Table & Repeat migration check per table -----------------------------

	done := make(chan bool, len(tables))
	for i := range tables {
		go func(table, bucket, prefix string) {
			watchTable(table, bucket, prefix, quit)
			done <- true
		}(tables[i], cfg.S3.Buckets[i], cfg.S3.Prefixes[i])
	}

	// Wait for every watcher, not just the first one to finish.
	for range tables {
		<-done
	}
}

// watchTable opens a Redshift connection, verifies that table has columns, and
// then runs a load pass for bucket/prefix every cfg.Default.Pollsleepinseconds
// seconds until quit is closed. Failures are passed to reportError, which panics.
func watchTable(table, bucket, prefix string, quit <-chan struct{}) {
	db, err := sql.Open("postgres", fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", cfg.Redshift.Host, cfg.Redshift.Port, cfg.Redshift.User, cfg.Redshift.Password, cfg.Redshift.Database))
	if err != nil {
		reportError("Couldn't connect to redshift database: ", err)
	}
	defer db.Close()

	requireTableColumns(db, table)

	pollInterval := time.Duration(cfg.Default.Pollsleepinseconds*1000) * time.Millisecond
	for {
		select {
		case <-quit:
			if cfg.Default.Debug {
				fmt.Printf("Quit signal received on %s watcher. Going down...\n", table)
			}
			return
		default:
		}

		pollTableOnce(db, table, bucket, prefix)

		// Sleep until the next poll, but wake immediately on quit.
		select {
		case <-quit:
		case <-time.After(pollInterval):
		}
	}
}

// requireTableColumns lists table's columns from INFORMATION_SCHEMA.COLUMNS
// (printing them in debug mode) and reports an error if none are found.
func requireTableColumns(db *sql.DB, table string) {
	// table comes from the config file and is interpolated directly into SQL.
	rows, err := db.Query(fmt.Sprintf("select COLUMN_NAME, DATA_TYPE from INFORMATION_SCHEMA.COLUMNS where table_name = '%s' limit 1000", table))
	if err != nil {
		reportError("Couldn't execute statement for INFORMATION_SCHEMA.COLUMNS: ", err)
	}
	defer rows.Close()

	if cfg.Default.Debug {
		fmt.Println("Looking for table, columns will display below.")
	}
	anyRows := false
	for rows.Next() {
		var columnName, dataType string
		if err := rows.Scan(&columnName, &dataType); err != nil {
			reportError("Couldn't scan row for table: ", err)
		}
		if cfg.Default.Debug {
			fmt.Printf("   %s, %s\n", columnName, dataType)
		}
		anyRows = true
	}
	if err := rows.Err(); err != nil {
		reportError("Couldn't read columns for table: ", err)
	}

	if !anyRows {
		reportError("Table had no columns: ", fmt.Errorf("no columns found in INFORMATION_SCHEMA.COLUMNS for table %q", table))
	}
	if cfg.Default.Debug {
		fmt.Println("Table found, will not migrate")
	}
}

// pollTableOnce performs one load pass for table. If STL_FILE_SCAN has no rows
// at all, it runs the generic COPY for the whole bucket/prefix. Otherwise it
// COPYs, one file at a time, every object under bucket/prefix whose key is not
// recorded in STL_FILE_SCAN.
func pollTableOnce(db *sql.DB, table, bucket, prefix string) {
	if cfg.Default.Debug {
		fmt.Printf("Re-polling with %s watcher.\n", table)
	}

	// ----------------------------- Take a look at STL_FILE_SCAN on this Table to see if any files have already been imported -----------------------------
	loadedFiles := loadedS3Keys(db, bucket)

	// ----------------------------- If not: run generic COPY for this Bucket/Prefix/Table -----------------------------
	if len(loadedFiles) == 0 {
		copyStmt := defaultCopyStmt(&table, &bucket, &prefix)
		if cfg.Default.Debug {
			fmt.Printf("No records found in STL_FILE_SCAN, running `%s`\n", copyStmt)
		}
		if _, err := db.Exec(copyStmt); err != nil {
			reportError("Couldn't execute default copy statement: ", err)
		}
		return
	}

	// ----------------------------- If yes: diff STL_FILE_SCAN with S3 bucket files list, COPY and missing files into this Table -----------------------------
	if cfg.Default.Debug {
		fmt.Printf("Records found, have to do manual copies from now on.\n")
	}
	nonLoadedFiles := unloadedS3Keys(bucket, prefix, loadedFiles)
	if cfg.Default.Debug {
		fmt.Printf("Haven't ever loaded %#v.\n", nonLoadedFiles)
	}
	for _, s3key := range nonLoadedFiles {
		copyStmt := defaultCopyStmt(&table, &bucket, &s3key)
		if cfg.Default.Debug {
			fmt.Printf("  Copying `%s`\n", copyStmt)
		}
		if _, err := db.Exec(copyStmt); err != nil {
			reportError("Couldn't execute default copy statement: ", err)
		}
	}
}

// loadedS3Keys returns the set of file names recorded in STL_FILE_SCAN. Each name
// is whitespace-trimmed, and its "s3://<bucket>/" prefix is removed so that it
// compares equal to an S3 object key in bucket. The set is empty when
// STL_FILE_SCAN has no rows.
func loadedS3Keys(db *sql.DB, bucket string) map[string]bool {
	// Explicit column list: the 8 scan targets below must not depend on how many
	// columns the system table happens to have.
	rows, err := db.Query("select userid, query, slice, name, lines, bytes, loadtime, curtime from STL_FILE_SCAN")
	if err != nil {
		reportError("Couldn't execute STL_FILE_SCAN: ", err)
	}
	defer rows.Close()

	loaded := map[string]bool{}
	bucketURL := fmt.Sprintf("s3://%s/", bucket)
	for rows.Next() {
		var (
			userid   int
			query    int
			slice    int
			name     string
			lines    int64
			bytes    int64
			loadtime int64
			curtime  time.Time
		)
		if err := rows.Scan(&userid, &query, &slice, &name, &lines, &bytes, &loadtime, &curtime); err != nil {
			reportError("Couldn't scan row for STL_FILE_SCAN: ", err)
		}
		if cfg.Default.Debug {
			fmt.Printf("  Already loaded: %d|%d|%d|%s|%d|%d|%d|%s\n", userid, query, slice, name, lines, bytes, loadtime, curtime)
		}
		loaded[strings.TrimPrefix(strings.TrimSpace(name), bucketURL)] = true
	}
	if err := rows.Err(); err != nil {
		reportError("Couldn't read STL_FILE_SCAN: ", err)
	}
	return loaded
}

// unloadedS3Keys lists every object under prefix in bucket, following the
// listing's truncation marker page by page, and returns the keys that are not
// present in loaded.
func unloadedS3Keys(bucket, prefix string, loaded map[string]bool) []string {
	s3bucket := s3.New(*aws.NewAuth(cfg.Aws.Accesskey, cfg.Aws.Secretkey, "", time.Now()), aws.Regions[cfg.Aws.Region]).Bucket(bucket)

	nonLoadedFiles := []string{}
	keyMarker := ""
	for {
		if cfg.Default.Debug {
			fmt.Printf("Checking s3 bucket %s.\n", bucket)
		}
		results, err := s3bucket.List(prefix, "", keyMarker, 0)
		if err != nil {
			reportError("Couldn't list default s3 bucket: ", err)
		}
		if cfg.Default.Debug {
			fmt.Printf("s3bucket.List returned %#v.\n", results)
		}
		if len(results.Contents) == 0 {
			break // empty page: assume every file has been seen
		}
		for _, s3obj := range results.Contents {
			key := strings.TrimSpace(s3obj.Key)
			if cfg.Default.Debug {
				fmt.Printf("Checking whether or not %s was preloaded.\n", key)
			}
			if !loaded[key] {
				nonLoadedFiles = append(nonLoadedFiles, s3obj.Key)
			}
		}
		if !results.IsTruncated {
			break
		}
		keyMarker = results.Contents[len(results.Contents)-1].Key
	}
	return nonLoadedFiles
}
Example #3
0
File: aws.go Project: hsbt/stretcher
// LoadAWSCredentials resolves AWS credentials and a region for profileName.
//
// An empty profileName means $AWS_DEFAULT_PROFILE, falling back to
// AWSDefaultProfileName. Sources are consulted in this order. A later source
// overrides an earlier one whenever it yields a valid value.
//  1. The config file ($AWS_CONFIG_FILE, or ~/.aws/config). The section is
//     AWSDefaultProfileName for the default profile, otherwise "profile NAME".
//  2. The "credentials" file in the same directory as the config file, with
//     section NAME.
//  3. The environment: AWS_DEFAULT_REGION, and AWS_ACCESS_KEY_ID together with
//     AWS_SECRET_ACCESS_KEY.
//
// If that does not produce both valid credentials and a valid region, EC2
// instance (IAM role) credentials are tried and replace any file/env
// credentials. A non-nil error means no valid credentials/region pair was
// found. The partially resolved values are still returned in that case.
func LoadAWSCredentials(profileName string) (aws.Auth, aws.Region, error) {
	if profileName == "" {
		if p := os.Getenv("AWS_DEFAULT_PROFILE"); p != "" {
			profileName = p
		} else {
			profileName = AWSDefaultProfileName
		}
	}

	var awsAuth aws.Auth
	var awsRegion aws.Region

	// load from File (~/.aws/config, ~/.aws/credentials)
	configFile := os.Getenv("AWS_CONFIG_FILE")
	if configFile == "" {
		if dir, err := homedir.Dir(); err == nil {
			configFile = filepath.Join(dir, ".aws", "config")
		}
	}

	// Skip the files entirely when their location is unknown (no
	// AWS_CONFIG_FILE and no home dir). filepath.Split("") yields dir == "",
	// which would otherwise make the credentials lookup read ./credentials from
	// the current working directory.
	if configFile != "" {
		dir, _ := filepath.Split(configFile)

		section := AWSDefaultProfileName
		if profileName != AWSDefaultProfileName {
			section = "profile " + profileName
		}
		// Errors are ignored on purpose: either file may legitimately be
		// missing, and the environment / IAM role fallbacks below still apply.
		auth, region, _ := loadAWSConfigFile(configFile, section)
		if isValidAuth(auth) {
			awsAuth = auth
		}
		if isValidRegion(region) {
			awsRegion = region
		}

		// The credentials file is keyed by the bare profile name.
		auth, region, _ = loadAWSConfigFile(filepath.Join(dir, "credentials"), profileName)
		if isValidAuth(auth) {
			awsAuth = auth
		}
		if isValidRegion(region) {
			awsRegion = region
		}
	}

	// Override by environment variable
	if region := os.Getenv("AWS_DEFAULT_REGION"); region != "" {
		awsRegion = aws.GetRegion(region)
	}
	if os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "" {
		if auth, _ := aws.EnvAuth(); isValidAuth(auth) {
			awsAuth = auth
		}
	}
	if isValidAuth(awsAuth) && isValidRegion(awsRegion) {
		return awsAuth, awsRegion, nil
	}

	// Otherwise, use IAM Role
	cred, err := aws.GetInstanceCredentials()
	if err == nil {
		exptdate, err := time.Parse("2006-01-02T15:04:05Z", cred.Expiration)
		if err == nil {
			auth := aws.NewAuth(cred.AccessKeyId, cred.SecretAccessKey, cred.Token, exptdate)
			awsAuth = *auth
		}
	}
	if isValidAuth(awsAuth) && isValidRegion(awsRegion) {
		return awsAuth, awsRegion, nil
	}

	return awsAuth, awsRegion, errors.New("cannot detect valid credentials or region")
}