示例#1
0
// TestLoadConfig encodes a Config as TOML into a temporary file and verifies
// that LoadConfig reads back a value that is deeply equal to the original.
func TestLoadConfig(t *testing.T) {
	aws := AWSConfig{
		Accesskey: "CF2DC307CC89F49F68F365235AA54BAB2DDD02DA",
		SecretKey: "ca485e5b709eae0ceacd68b61cdb28119f13942c71bbb68066c8a3cb45185a39",
	}

	testName := utils.GetAppName() + "-test"
	tempDir, err := ioutil.TempDir("", testName)
	if err != nil {
		t.Fatalf("failed to create the temp dir: %s", err.Error())
	}
	// Register the cleanup immediately so the directory is removed even when
	// a later step aborts the test.
	defer os.RemoveAll(tempDir)

	out := OutConfig{
		Root: tempDir,
		File: true,
		Bom:  true,
	}

	log := LogConfig{
		Root:    tempDir,
		Verbose: true,
		JSON:    true,
	}

	rds := RDSConfig{
		MultiAz: true,
		DBId:    utils.GetFormatedDBDisplayName(testName),
		Region:  "us-west-2",
		User:    "******",
		Pass:    "******",
		Type:    "db.m3.medium",
	}
	rdsMap := map[string]RDSConfig{
		"default": rds,
	}

	config := &Config{
		Aws: aws,
		Out: out,
		Rds: rdsMap,
		Log: log,
	}

	// Every step below is a precondition for the next one, so failures are
	// fatal: continuing would dereference a nil file or compare against nil.
	tempFile, err := ioutil.TempFile(tempDir, utils.GetAppName()+"-test")
	if err != nil {
		t.Fatalf("failed to create the temp file: %s", err.Error())
	}
	if err := toml.NewEncoder(tempFile).Encode(config); err != nil {
		tempFile.Close()
		t.Fatalf("failed to create the toml file: %s", err.Error())
	}
	tempFile.Sync()
	if err := tempFile.Close(); err != nil {
		t.Fatalf("failed to close the temp file: %s", err.Error())
	}

	conf, err := LoadConfig(tempFile.Name())
	if err != nil {
		t.Fatalf("config file load error: %s", err.Error())
	}
	if !reflect.DeepEqual(config, conf) {
		t.Errorf("config data not match: %+v/%+v", config, conf)
	}
}
示例#2
0
// getTestClient starts an httptest.Server that answers every request with the
// given status code and body (served as application/xml) and returns it
// together with a Command whose RDS client sends its requests to that server.
//
// The caller must Close the returned server.
//
// NOTE(review): the original comment asks the caller to remove
// Command.OutConfig.Root, but the deferred os.RemoveAll below already deletes
// that directory before this function returns — confirm which is intended.
func getTestClient(code int, body string) (*httptest.Server, *Command) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Headers must be set before WriteHeader; changes made afterwards
		// are not sent to the client.
		w.Header().Set("Content-Type", "application/xml")
		w.WriteHeader(code)
		fmt.Fprintln(w, body)
	}))

	// Route every outgoing request through the test server.
	transport := &http.Transport{
		Proxy: func(req *http.Request) (*url.URL, error) {
			return url.Parse(server.URL)
		},
	}

	httpClient := &http.Client{Transport: transport}

	// Override endpoints
	testRegion := "rds-try-test-1"
	awsConf := aws.NewConfig()
	awsConf = awsConf.WithCredentials(credentials.NewStaticCredentials("awsAccesskey1", "awsSecretKey2", ""))
	awsConf = awsConf.WithRegion(testRegion)
	awsConf = awsConf.WithEndpoint(server.URL)
	awsConf = awsConf.WithHTTPClient(httpClient)

	awsRds := rds.New(awsConf)

	testName := utils.GetAppName() + "-test"
	// The signature has no error return, so a TempDir failure is not reported
	// here; tempDir is then the empty string.
	tempDir, _ := ioutil.TempDir("", testName)
	defer os.RemoveAll(tempDir)

	out := config.OutConfig{
		Root: tempDir,
		File: true,
		Bom:  true,
	}

	rdsConf := config.RDSConfig{
		MultiAz: false,
		DBId:    utils.GetFormatedDBDisplayName(testName),
		Region:  testRegion,
		User:    "******",
		Pass:    "******",
		Type:    "db.m3.medium",
	}

	cmdTest := &Command{
		OutConfig: out,
		RDSConfig: rdsConf,
		RDSClient: awsRds,
		ARNPrefix: "arn:aws:rds:" + testRegion + ":" + "123456789" + ":",
	}

	return server, cmdTest
}
示例#3
0
// getTestConfig builds a sample Config, writes it as TOML into a fresh
// temporary directory and returns the Config together with that directory.
//
// The caller must always run "defer os.RemoveAll(tempDir)" on the returned
// directory, including when an error is returned.
//
// Side effects: os.Args is replaced so that "-c" points at the generated file,
// and flag.CommandLine is reset to a new FlagSet.
//
// NOTE(review): the RDS entry is stored under the key "default2" while the
// arguments select "-n=default" — confirm the mismatch is intentional.
func getTestConfig() (*config.Config, string, error) {
	awsConf := config.AWSConfig{
		Accesskey: "CF2DC307CC89F49F68F365235AA54BAB2DDD02DA",
		SecretKey: "ca485e5b709eae0ceacd68b61cdb28119f13942c71bbb68066c8a3cb45185a39",
	}

	testName := utils.GetAppName() + "-test"
	tempDir, err := ioutil.TempDir("", testName)
	if err != nil {
		return nil, tempDir, err
	}

	outConf := config.OutConfig{
		Root: tempDir,
		File: true,
		Bom:  true,
	}

	logConf := config.LogConfig{
		Root:    tempDir,
		Verbose: true,
		JSON:    true,
	}

	rdsConf := config.RDSConfig{
		MultiAz: true,
		DBId:    utils.GetFormatedDBDisplayName(testName),
		Region:  "us-west-2",
		User:    "******",
		Pass:    "******",
		Type:    "db.m3.medium",
	}
	rdsMap := map[string]config.RDSConfig{
		"default2": rdsConf,
	}

	conf := &config.Config{
		Aws: awsConf,
		Out: outConf,
		Rds: rdsMap,
		Log: logConf,
	}

	tempFile, err := ioutil.TempFile(tempDir, utils.GetAppName()+"-test")
	if err != nil {
		return nil, tempDir, err
	}
	if err := toml.NewEncoder(tempFile).Encode(conf); err != nil {
		// Do not leak the descriptor when encoding fails.
		tempFile.Close()
		return nil, tempDir, err
	}
	tempFile.Sync()
	if err := tempFile.Close(); err != nil {
		return nil, tempDir, err
	}

	os.Args = []string{utils.GetAppName(), "-c=" + tempFile.Name(), "-n=default", "ls"}
	flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)

	return conf, tempDir, nil
}
示例#4
0
// CreateDBSnapshot requests a new snapshot of the RDS instance identified by
// dbIdentifier. The snapshot identifier is produced by
// utils.GetFormatedDBDisplayName(dbIdentifier) and the tags returned by
// getSpecifyTags are attached to the request.
//
// On failure the error is logged and returned together with a nil snapshot.
func (c *Command) CreateDBSnapshot(dbIdentifier string) (*rds.DBSnapshot, error) {
	snapID := utils.GetFormatedDBDisplayName(dbIdentifier)

	resp, err := c.RDSClient.CreateDBSnapshot(&rds.CreateDBSnapshotInput{
		DBInstanceIdentifier: &dbIdentifier,
		DBSnapshotIdentifier: &snapID,
		// The tags must always be set so they are never forgotten.
		Tags: getSpecifyTags(),
	})
	if err != nil {
		log.Errorf("%s", err.Error())
		return nil, err
	}

	return resp.DBSnapshot, nil
}
示例#5
0
// runDetails restores a DB instance from a snapshot, brings its settings in
// line with the running instance, executes the loaded queries against it and
// prints the per-query and total runtimes to stdout.
//
// Steps:
//  1. load the queries from c.OptQuery, or from query.GetDefaultPath() when
//     c.OptQuery is empty;
//  2. create a new snapshot when c.OptSnap is set, otherwise use the latest one;
//  3. restore a new instance from that snapshot, then modify and reboot it,
//     rebooting again while c.CheckPendingStatus still reports true;
//  4. run the queries and print the timing summary.
//
// The f parameter is not used. ErrDBInstancetTimeOut is returned when a wait
// channel reports false or when the pending status does not clear after six
// additional reboots; any other error comes from the failing step unchanged.
func (c *EsCommand) runDetails(f *flag.FlagSet) error {
	// load query
	queryFile := query.GetDefaultPath()
	if c.OptQuery != "" {
		queryFile = c.OptQuery
	}
	queries, err := query.LoadQuery(queryFile)
	if err != nil {
		return err
	}
	log.Debugf("%+v", queries)

	// option create snapshot
	// or
	// get latest db snap shot
	var snapShot *rds.DBSnapshot
	if c.OptSnap {
		snapShot, err = c.CreateDBSnapshot(c.RDSConfig.DBId)
		if err != nil {
			return err
		}

		// wait for available
		waitChan := c.WaitForStatusAvailable(snapShot)
		if !<-waitChan {
			return ErrDBInstancetTimeOut
		}
	} else {
		snapShot, err = c.DescribeLatestDBSnapshot(c.RDSConfig.DBId)
		if err != nil {
			return err
		}
	}

	// get now active db info
	// to-do: can not run if the running instance does not exist
	actDB, err := c.DescribeDBInstance(c.RDSConfig.DBId)
	if err != nil {
		return err
	}

	// "DBInstanceClass" is determined in the following order
	// 1. argument value
	// 2. config file type
	// 3. running DB Instance Class
	restType := *actDB.DBInstanceClass
	if c.RDSConfig.Type != "" {
		// Use the config file value here; the argument value, when present,
		// overrides it just below.
		restType = c.RDSConfig.Type
	}
	if c.OptType != "" {
		restType = c.OptType
	}
	restName := utils.GetFormatedDBDisplayName(c.RDSConfig.DBId)
	restArgs := &RestoreDBInstanceFromDBSnapshotArgs{
		DBInstanceClass: restType,
		DBIdentifier:    restName,
		MultiAZ:         c.RDSConfig.MultiAz,
		Snapshot:        snapShot,
		Instance:        actDB,
	}
	restDB, err := c.RestoreDBInstanceFromDBSnapshot(restArgs)
	if err != nil {
		return err
	}
	log.Infof("%+v", *restArgs)

	// wait for available
	waitChan := c.WaitForStatusAvailable(restDB)
	if !<-waitChan {
		return ErrDBInstancetTimeOut
	}

	// DB is restored in the default state
	// So, I do modify
	restDB, err = c.ModifyDBInstance(restName, actDB)
	if err != nil {
		return err
	}

	// wait for available
	waitChan = c.WaitForStatusAvailable(restDB)
	if !<-waitChan {
		return ErrDBInstancetTimeOut
	}

	// enable the setting by performing reboot
	restDB, err = c.RebootDBInstance(restName)
	if err != nil {
		return err
	}

	// wait for available
	waitChan = c.WaitForStatusAvailable(restDB)
	if !<-waitChan {
		return ErrDBInstancetTimeOut
	}

	// get db info
	restDB, err = c.DescribeDBInstance(restName)
	if err != nil {
		return err
	}

	// setting check
	// count starts at 1 because one reboot has already been performed above.
	var count = 1
	for c.CheckPendingStatus(restDB) {
		// max count
		if count > 6 {
			return ErrDBInstancetTimeOut
		}

		count++
		log.Infof("restart %d times! because change has not been applied", count)

		// once again reboot
		restDB, err = c.RebootDBInstance(restName)
		if err != nil {
			return err
		}

		// wait for available
		waitChan = c.WaitForStatusAvailable(restDB)
		if !<-waitChan {
			return ErrDBInstancetTimeOut
		}

		// get db info
		restDB, err = c.DescribeDBInstance(restName)
		if err != nil {
			return err
		}
	}

	// run queries
	times, err := c.ExecuteSQL(
		&ExecuteSQLArgs{
			Engine:   *restDB.Engine,
			Endpoint: restDB.Endpoint,
			Queries:  queries.Query,
		})
	if err != nil {
		return err
	}

	// show total time
	var total float64
	totalText := "\nruntime result:\n"
	for i, elapsed := range times {
		total += elapsed.Seconds()
		totalText += fmt.Sprintf("  query name   : %s\n  query runtime: %s\n\n", queries.Query[i].Name, elapsed.String())
	}

	// Split the total into whole hours, whole minutes and fractional seconds.
	hour := int(total) / 3600
	minute := (int(total) - hour*3600) / 60
	second := total - float64(hour*3600) - float64(minute*60)

	totalText += "--------------------------------\n"
	timeText := fmt.Sprintf("  total runtime: %.3f sec\n", second)
	if minute > 0 {
		timeText = fmt.Sprintf("  total runtime: %d m %.3f sec\n", minute, second)
	}
	if hour > 0 {
		timeText = fmt.Sprintf("  total runtime: %d h %d m %.3f sec\n", hour, minute, second)
	}
	totalText += timeText
	fmt.Println(totalText)

	return nil
}