예제 #1
0
파일: util.go 프로젝트: qjpcpu/sexy-ssh
// ParallelRun runs the configured command concurrently on the hosts
// raw_host_arr[start:end]. Once every host has reported END, it writes each
// host's collected stderr and then stdout to the configured writers.
//
// config keys read here:
//   - "User", "Password", "Keyfile", "Cmd", "Args": strings.
//   - "Timeout": int.
//   - "Output", "Errout": io.Writer.
//
// Missing or mistyped keys fall back to zero values. Nil writers fall back to
// os.Stdout / os.Stderr.
//
// Each host writes into its own pair of files in a private temporary
// directory created inside tmpdir. That directory is removed when the
// function returns, or before exiting with status 1 on an interrupt.
//
// It returns an error if the temporary directory or a per-host file cannot be
// created, or if the job manager delivers a malformed message. Like the
// original, it panics if start:end is not a valid range of raw_host_arr.
func ParallelRun(config map[string]interface{}, raw_host_arr []string, start, end int, tmpdir string) error {
	host_arr := raw_host_arr[start:end]
	if len(host_arr) == 0 {
		// No workers means no END messages: the wait loop below would block forever.
		return nil
	}

	user, _ := config["User"].(string)
	pwd, _ := config["Password"].(string)
	keyfile, _ := config["Keyfile"].(string)
	cmd, _ := config["Cmd"].(string)
	args, _ := config["Args"].(string)
	timeout, _ := config["Timeout"].(int)
	cmd = format_cmd(cmd, args)
	printer, _ := config["Output"].(io.Writer)
	if printer == nil {
		// io.Copy into a nil writer would panic during the merge step.
		printer = os.Stdout
	}
	err_printer, _ := config["Errout"].(io.Writer)
	if err_printer == nil {
		err_printer = os.Stderr
	}

	// The manager (master) routes BEGIN/CONTINUE/END messages between this
	// function and the per-host workers.
	// NOTE(review): the second return value of job.NewManager is discarded;
	// if it is an error it should be checked — confirm its type.
	mgr, _ := job.NewManager()

	// The directory is private to this batch. PID plus the full Unix
	// nanosecond timestamp avoids collisions between concurrent s3h processes
	// sharing tmpdir; Nanosecond() alone repeats every second. Mode 0700
	// because the files hold remote command output.
	dir := fmt.Sprintf("%s/.s3h.%d.%d", tmpdir, os.Getpid(), time.Now().UnixNano())
	if err := os.Mkdir(dir, 0700); err != nil {
		return err
	}

	// On interrupt, delete the temp directory and exit with status 1.
	// (os.Kill/SIGKILL cannot be caught, so it is not registered.)
	intqueue := make(chan os.Signal, 1)
	done := make(chan struct{})
	signal.Notify(intqueue, os.Interrupt)
	go func() {
		select {
		case <-intqueue:
			os.RemoveAll(dir)
			os.Exit(1)
		case <-done:
			// Normal completion: let this goroutine exit instead of leaking
			// one blocked listener per batch.
		}
	}()
	// On normal return: stop listening, release the listener goroutine and
	// remove the temp directory.
	defer func() {
		signal.Stop(intqueue)
		close(done)
		os.RemoveAll(dir)
	}()

	// Create both temp files for every host before starting any worker. A
	// creation failure can then be reported without leaving workers running.
	// Layout: tmpfiles[2*i] is host i's stdout file, tmpfiles[2*i+1] its stderr file.
	tmpfiles := make([]*os.File, 0, 2*len(host_arr))
	closeAll := func() {
		for _, f := range tmpfiles {
			f.Close()
		}
	}
	for _, h := range host_arr {
		file, err := os.Create(fmt.Sprintf("%s/%s", dir, h))
		if err != nil {
			closeAll()
			return err
		}
		tmpfiles = append(tmpfiles, file)
		err_file, err := os.Create(fmt.Sprintf("%s/%s.err", dir, h))
		if err != nil {
			closeAll()
			return err
		}
		tmpfiles = append(tmpfiles, err_file)
	}
	for i, h := range host_arr {
		s3h := sssh.NewS3h(h, user, pwd, keyfile, cmd, tmpfiles[2*i], tmpfiles[2*i+1], mgr)
		s3h.Timeout = timeout
		go s3h.Work()
	}

	// When stdout is a terminal, show a live view of every host's output file.
	// This is best effort: if the view cannot be set up, run without it.
	var dc *dircat.DirCat
	if terminal.IsTerminal(1) {
		wlist := make([]string, 0, len(host_arr))
		for _, h := range host_arr {
			wlist = append(wlist, fmt.Sprintf("%s/%s", dir, h))
		}
		dc, _ = dircat.Init(wlist...)
		if dc != nil {
			go dc.Start()
		}
	}

	// Master loop. Answer each worker's BEGIN with CONTINUE so it may run, and
	// count END messages until every host in this batch has finished.
	var loopErr error
	for remaining := len(host_arr); remaining > 0 && loopErr == nil; {
		data, _ := mgr.Receive(-1)
		info, _ := data.(map[string]interface{})
		body, ok := info["BODY"].(string)
		if !ok {
			loopErr = fmt.Errorf("unexpected message from job manager: %v", data)
			continue
		}
		switch body {
		case "BEGIN":
			from, ok := info["FROM"].(string)
			if !ok {
				loopErr = fmt.Errorf("BEGIN message without a sender: %v", data)
				continue
			}
			mgr.Send(from, map[string]interface{}{"FROM": job.MASTER_ID, "BODY": "CONTINUE"})
		case "END":
			remaining--
		}
	}

	if dc != nil {
		dc.Stop()
	}
	// Close the temp files so their contents are complete before merging.
	closeAll()
	if loopErr != nil {
		return loopErr
	}

	// Merge every host's output, in host order: stderr first, then stdout.
	for _, h := range host_arr {
		report(os.Stderr, "", h, true)
		err_fn := fmt.Sprintf("%s/%s.err", dir, h)
		fn := fmt.Sprintf("%s/%s", dir, h)
		if err_src, err := os.Open(err_fn); err == nil {
			io.Copy(err_printer, err_src)
			err_src.Close()
		}
		if src, err := os.Open(fn); err == nil {
			io.Copy(printer, src)
			src.Close()
		}
		// Free space early; the deferred RemoveAll cleans up anything left.
		os.Remove(err_fn)
		os.Remove(fn)
	}
	return nil
}
예제 #2
0
파일: main.go 프로젝트: qjpcpu/sexy-ssh
// main is the s3h entry point. It:
//  1. Parses flags.
//  2. Resolves the host list (host file, inline list, or stdin).
//  3. Resolves credentials (flags, then the s3h rc file, then environment
//     defaults).
//  4. Either copies a file to every host (scp mode) or renders the command
//     template and runs it on every host, serially or in parallel batches.
func main() {
	options := SeshFlags{
		Debug:          false,
		Pause:          false,
		Tmpdir:         ".",
		Parallel:       false,
		ParallelDegree: 0,
	}
	goptions.ParseAndFail(&options)

	// Fall back to a 5-unit timeout when the flag is unset or non-positive.
	if options.Timeout < 1 {
		options.Timeout = 5
	}

	// Hosts come from, in order of precedence: a host file, an inline list, or stdin.
	var host_arr []string
	switch {
	case options.Hostfile != "":
		buf, err := ioutil.ReadFile(options.Hostfile)
		if err != nil {
			fmt.Fprintf(os.Stderr, "\033[31mFailed to read host from file! %v\033[0m\n", err)
			return
		}
		host_arr = parseHostsFromString(string(buf))
	case options.Hostlist != "":
		host_arr = parseHostsFromString(options.Hostlist)
	default:
		if terminal.IsTerminal(0) {
			fmt.Fprintln(os.Stderr, "\033[33mPlease input hosts, separated by newlines, press Ctrl+D to finish input:\033[0m")
		}
		buf, err := ioutil.ReadAll(os.Stdin)
		if err != nil {
			fmt.Fprintf(os.Stderr, "\033[31mFailed to read hosts from stdin! %v\033[0m\n", err)
			return
		}
		host_arr = parseHostsFromString(string(buf))
	}

	// User: flag, else the rc file's "default" section, else $USER.
	// An explicit user selects the rc section of the same name, but only when
	// that section's "user" entry matches.
	rc, rcErr := util.Gets3hrc()
	rc_sec := "default"
	if options.User == "" {
		if rcErr == nil {
			options.User = rc[rc_sec]["user"]
		}
		if options.User == "" {
			options.User = os.Getenv("USER")
		}
	} else if sec, ok := rc[options.User]; ok && sec["user"] == options.User {
		rc_sec = options.User
	}

	// Password: taken from the rc section only when it belongs to the same user.
	if options.Password == "" && rcErr == nil && options.User == rc[rc_sec]["user"] && rc[rc_sec]["password"] != "" {
		options.Password = rc[rc_sec]["password"]
	}

	// Key file: flag, else rc section, else ~/.ssh/id_rsa. If that file is
	// missing, either use the password instead or require a reachable SSH agent.
	if options.Keyfile == "" {
		if rcErr == nil {
			options.Keyfile = rc[rc_sec]["keyfile"]
		}
		if options.Keyfile == "" {
			options.Keyfile = os.Getenv("HOME") + "/.ssh/id_rsa"
		}
		if _, err := os.Stat(options.Keyfile); os.IsNotExist(err) {
			if options.Password == "" {
				if os.Getenv("SSH_AUTH_SOCK") == "" {
					fmt.Fprintln(os.Stderr, "\033[31mKey file "+options.Keyfile+" not found!\033[0m")
					return
				}
				conn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
				if err != nil {
					fmt.Fprintln(os.Stderr, "\033[31mKey file "+options.Keyfile+" not found!\033[0m")
					return
				}
				// This only probes that the agent is reachable; don't leak the socket.
				conn.Close()
			} else {
				options.Keyfile = ""
			}
		}
	}

	// Scp mode: copy a file to every host, then stop.
	if options.Sscp.Src != "" && options.Sscp.Destdir != "" {
		config := map[string]interface{}{
			"User":     options.User,
			"Password": options.Password,
			"Keyfile":  options.Keyfile,
			"Source":   options.Sscp.Src,
			"Destdir":  options.Sscp.Destdir,
			"Timeout":  options.Timeout,
		}
		if err := util.ScpRun(config, host_arr); err != nil {
			fmt.Fprintf(os.Stderr, "\033[31mCopy failed! %v\033[0m\n", err)
			return
		}
		fmt.Fprintln(os.Stderr, "\033[32mFinished!\033[0m")
		return
	}

	// A command (inline or from files) is required in command mode.
	if len(options.Cmd) == 0 && len(options.Cmdfile) == 0 {
		fmt.Fprintln(os.Stderr, "\033[31mPlease specify the command you want to execute.\033[0m")
		return
	}

	// Render the command template; command files take precedence over inline words.
	cmd := ""
	if len(options.Cmdfile) > 0 {
		for _, cf := range options.Cmdfile {
			if _, err := os.Stat(cf); os.IsNotExist(err) {
				fmt.Fprintln(os.Stderr, "\033[31mCommand file "+cf+" not found!\033[0m")
				return
			}
		}
		o, err := templ.ParseFromFiles(options.Cmdfile, parseData(options.Data))
		if err != nil {
			fmt.Fprintf(os.Stderr, "\033[31mParse command file failed!\033[0m\n%v\n", err)
			return
		}
		cmd = o
	} else {
		// Join the command words, each followed by a single space.
		for _, v := range options.Cmd {
			cmd = cmd + v + " "
		}
		o, err := templ.ParseFromString(cmd, parseData(options.Data))
		if err != nil {
			fmt.Fprintf(os.Stderr, "\033[31mParse command failed!\033[0m\n%v\n", err)
			return
		}
		cmd = o
	}
	// Parallel mode stores per-host output in temp files under Tmpdir.
	if _, err := os.Stat(options.Tmpdir); os.IsNotExist(err) && options.Parallel {
		fmt.Fprintln(os.Stderr, "\033[31mTemporary directory "+options.Tmpdir+" does not exist!\033[0m")
		return
	}

	config := map[string]interface{}{
		"User":     options.User,
		"Password": options.Password,
		"Keyfile":  options.Keyfile,
		"Cmd":      cmd,
		"Args":     options.Arguments,
		"Output":   os.Stdout,
		"Errout":   os.Stderr,
		"Timeout":  options.Timeout,
	}
	if options.Debug {
		printDebugInfo(options, host_arr, cmd)
		return
	}

	// Pause mode: run on the first host only, let the user log in to inspect
	// it, then continue with the remaining hosts. It is skipped when there are
	// no hosts, since host_arr[0] would panic.
	host_offset := 0
	if options.Pause && len(host_arr) > 0 {
		util.SerialRun(config, host_arr, host_offset, 1)
		fmt.Fprintf(os.Stderr, "The task on \033[33m%s\033[0m has done.\nPress Enter to auto login \033[33m%s\033[0m to have a check...", host_arr[0], host_arr[0])
		// One shared reader: a second bufio.Reader would lose input buffered by the first.
		reader := bufio.NewReader(os.Stdin)
		reader.ReadString('\n')
		util.Interact(config, host_arr[0])
		fmt.Fprintf(os.Stderr, "\n\033[32mCheck completed! Press Enter to accomplish the remaining tasks.\033[0m")
		reader.ReadString('\n')
		host_offset = 1
	}

	if options.Parallel {
		fmt.Fprintln(os.Stderr, util.GirlSay("  Please wait me for a moment, Baby!  "))
		end := len(host_arr)
		// A non-positive or oversized degree means "all remaining hosts in one batch".
		if options.ParallelDegree < 1 || options.ParallelDegree > (end-host_offset) {
			options.ParallelDegree = end - host_offset
		}
		for host_offset < end {
			to := host_offset + options.ParallelDegree
			if to > end {
				to = end
			}
			if host_offset >= to {
				break
			}
			if err := util.ParallelRun(config, host_arr, host_offset, to, options.Tmpdir); err != nil {
				fmt.Fprintf(os.Stderr, "\033[31mParallel run failed! %v\033[0m\n", err)
				return
			}
			host_offset = to
		}
	} else {
		util.SerialRun(config, host_arr, host_offset, len(host_arr))
	}

	fmt.Fprintln(os.Stderr, "\033[32mFinished!\033[0m")
}