Beispiel #1
0
// Process parses src (reported under filename), fixes and sorts its imports,
// and returns the re-printed source.
// A nil opt is replaced by Options{Comments: true, TabIndent: true, TabWidth: 8}.
func Process(filename string, src []byte, opt *Options) ([]byte, error) {
	if opt == nil {
		opt = &Options{Comments: true, TabIndent: true, TabWidth: 8}
	}

	fileSet := token.NewFileSet()
	file, adjust, err := parse(fileSet, filename, src, opt)
	if err != nil {
		return nil, err
	}

	if _, err = fixImports(fileSet, file); err != nil {
		return nil, err
	}

	sortImports(fileSet, file)

	// Walk every contiguous import block and remember each import path whose
	// group number differs from the one before it; a blank line is inserted
	// ahead of those paths after printing so the result stays gofmt-compatible.
	var spacesBefore []string
	for _, block := range astutil.Imports(fileSet, file) {
		prevGroup := -1
		for _, spec := range block {
			impPath, _ := strconv.Unquote(spec.Path.Value)
			group := importGroup(impPath)
			if prevGroup != -1 && group != prevGroup {
				spacesBefore = append(spacesBefore, impPath)
			}
			prevGroup = group
		}
	}

	mode := printer.UseSpaces
	if opt.TabIndent {
		mode |= printer.TabIndent
	}
	cfg := printer.Config{Mode: mode, Tabwidth: opt.TabWidth}

	var buf bytes.Buffer
	if err = cfg.Fprint(&buf, fileSet, file); err != nil {
		return nil, err
	}

	result := buf.Bytes()
	if adjust != nil {
		result = adjust(src, result)
	}
	if len(spacesBefore) != 0 {
		result = addImportSpaces(bytes.NewReader(result), spacesBefore)
	}
	return result, nil
}
Beispiel #2
0
// Process parses src (reported under filename), fixes and sorts its imports,
// and returns the re-printed source.
// A nil opt is replaced by Options{Comments: true, TabIndent: true, TabWidth: 8}.
func Process(filename string, src []byte, opt *Options) ([]byte, error) {
	if opt == nil {
		opt = &Options{Comments: true, TabIndent: true, TabWidth: 8}
	}

	fileSet := token.NewFileSet()
	file, adjust, err := parse(fileSet, filename, src, opt)
	if err != nil {
		return nil, err
	}

	if _, err = fixImports(file); err != nil {
		return nil, err
	}

	sortImports(fileSet, file)

	// Only when the file has exactly one import block: remember each import
	// path whose group number differs from its predecessor's, so that a blank
	// line can be inserted before it once the file has been printed.
	var spacesBefore []string
	if blocks := astutil.Imports(fileSet, file); len(blocks) == 1 {
		prevGroup := -1
		for _, spec := range blocks[0] {
			impPath, _ := strconv.Unquote(spec.Path.Value)
			group := importGroup(impPath)
			if prevGroup != -1 && group != prevGroup {
				spacesBefore = append(spacesBefore, impPath)
			}
			prevGroup = group
		}
	}

	mode := printer.UseSpaces
	if opt.TabIndent {
		mode |= printer.TabIndent
	}
	cfg := printer.Config{Mode: mode, Tabwidth: opt.TabWidth}

	var buf bytes.Buffer
	if err = cfg.Fprint(&buf, fileSet, file); err != nil {
		return nil, err
	}

	result := buf.Bytes()
	if adjust != nil {
		result = adjust(src, result)
	}
	if len(spacesBefore) != 0 {
		result = addImportSpaces(bytes.NewReader(result), spacesBefore)
	}
	return result, nil
}
Beispiel #3
0
// processFile reads Go source, fixes and sorts its imports, re-prints it and
// then reports the result as selected by *list, *write and *doDiff.
//
// If in == nil, the source is the contents of the file with the given filename.
// Otherwise the source is read from in and filename is only used for parsing,
// reporting and (with *write) as the destination.
//
// When the formatted result differs from the source:
//   - *list prints filename to out,
//   - *write overwrites filename with the result,
//   - *doDiff writes a diff header followed by the diff to out.
//
// When none of the three is set, the formatted result is written to out.
// The first parse, print, diff or write error is returned.
func processFile(filename string, in io.Reader, out io.Writer, stdin bool) error {
	initModesOnce.Do(initModes)

	if in == nil {
		f, err := os.Open(filename)
		if err != nil {
			return err
		}
		defer f.Close()
		in = f
	}

	src, err := ioutil.ReadAll(in)
	if err != nil {
		return err
	}

	file, adjust, err := parse(fileSet, filename, src, stdin)
	if err != nil {
		return err
	}

	added, err := fixImports(file)
	if err != nil {
		return err
	}

	sortImports(fileSet, file)
	imps := astutil.Imports(fileSet, file)

	var spacesBefore []string // import paths we need spaces before
	if len(imps) == 1 && len(added) > 0 {
		// We have just one block of imports and fixImports added something.
		// See if any neighbouring imports are in different group numbers.
		lastGroup := -1
		for _, importSpec := range imps[0] {
			importPath, _ := strconv.Unquote(importSpec.Path.Value)
			groupNum := importGroup(importPath)
			if groupNum != lastGroup && lastGroup != -1 {
				spacesBefore = append(spacesBefore, importPath)
			}
			lastGroup = groupNum
		}
	}

	var buf bytes.Buffer
	err = (&printer.Config{Mode: printerMode, Tabwidth: *tabWidth}).Fprint(&buf, fileSet, file)
	if err != nil {
		return err
	}
	res := buf.Bytes()
	if adjust != nil {
		res = adjust(src, res)
	}
	if len(spacesBefore) > 0 {
		res = addImportSpaces(bytes.NewReader(res), spacesBefore)
	}

	if !bytes.Equal(src, res) {
		// formatting has changed
		if *list {
			if _, err := fmt.Fprintln(out, filename); err != nil {
				return err
			}
		}
		if *write {
			// The permission argument only matters if WriteFile has to create
			// the file; an existing file keeps its mode.
			err = ioutil.WriteFile(filename, res, 0)
			if err != nil {
				return err
			}
		}
		if *doDiff {
			data, err := diff(src, res)
			if err != nil {
				return fmt.Errorf("computing diff: %s", err)
			}
			// The header must go to the same writer as the diff body;
			// printing it to stdout separates the two whenever out is
			// something else.
			if _, err := fmt.Fprintf(out, "diff %s gofmt/%s\n", filename, filename); err != nil {
				return err
			}
			if _, err := out.Write(data); err != nil {
				return err
			}
		}
	}

	if !*list && !*write && !*doDiff {
		_, err = out.Write(res)
	}

	return err
}
Beispiel #4
0
// doRewritePackage works out which imports of pkg should be rewritten and
// passes the resulting mapping to rewriteImports.
//
// It parses pkg.GoFiles and pkg.CgoFiles, collects every import path that
// matches repositoryRe (skipping paths for which shouldKeepOriginalImport
// reports true in any file), asks st for the matching repos and then picks a
// replacement import path per repo depending on opts.LibraryMode(pkg),
// opts.PreferRevisions and the repo's Version/AllowsUnpinned fields.
// With opts.Interactive each rewrite has to be confirmed on the terminal.
//
// It returns nil without touching anything when no import qualifies or every
// rewrite was declined. Errors from path resolution, parsing, the repo
// request, the terminal/stdin prompt and rewriteImports are returned.
func doRewritePackage(pkg *build.Package, st *rewriteState, opts *rewriteOptions) error {
	libraryMode := opts.LibraryMode(pkg)
	abs, err := filepath.Abs(pkg.Dir)
	if err != nil {
		return err
	}
	fset := token.NewFileSet()
	var names []string
	names = append(names, pkg.GoFiles...)
	names = append(names, pkg.CgoFiles...)
	files, err := parseFiles(fset, abs, names, parser.ParseComments)
	if err != nil {
		return err
	}
	// First check if we should keep any original imports in the package due to
	// the use it makes of the imported pkg (type assertions, etc...).
	// This pass must finish over all files before the next one starts, since
	// a path disabled by one file is disabled for the whole package.
	disabled := make(map[string]bool)
	for _, v := range files {
		imports := astutil.Imports(fset, v)
		for _, group := range imports {
			for _, imp := range group {
				if unquoted, err := strconv.Unquote(imp.Path.Value); err == nil {
					m := repositoryRe.FindStringSubmatch(unquoted)
					if len(m) > 0 && shouldKeepOriginalImport(fset, unquoted, imp, v, opts) {
						disabled[unquoted] = true
					}
				}
			}
		}
	}
	// Collect the set of matched repository strings still eligible for rewriting.
	using := make(map[string]bool)
	for _, v := range files {
		imports := astutil.Imports(fset, v)
		for _, group := range imports {
			for _, imp := range group {
				if unquoted, err := strconv.Unquote(imp.Path.Value); err == nil {
					m := repositoryRe.FindStringSubmatch(unquoted)
					if len(m) > 0 && !disabled[unquoted] {
						using[m[0]] = true
					}
				}
			}
		}
	}
	// Now check imports we should rewrite
	if len(using) == 0 {
		return nil
	}
	repoNames := make([]string, 0, len(using))
	for k := range using {
		repoNames = append(repoNames, k)
	}
	if opts.Verbose {
		fmt.Printf("package %s uses %d 3rd party repositories: %v\n", pkgName(pkg), len(repoNames), repoNames)
	}
	repos, err := st.RequestRepos(repoNames)
	if err != nil {
		return err
	}
	rewrites := make(map[string]string)
addRewrites:
	for _, v := range repos {
		var importPath string
		if libraryMode {
			if v.Version == 0 {
				if v.AllowsUnpinned {
					importPath = v.GoPkgsPath
				} else {
					if opts.Verbose {
						fmt.Printf("ignoring package %s, no versions available\n", v.Path)
					}
					continue
				}
			} else {
				importPath = v.VersionImportPath()
			}
		} else {
			if opts.PreferRevisions {
				importPath = v.RevisionImportPath()
			} else {
				importPath = v.VersionImportPath()
			}
		}
		if opts.Interactive {
		prompt:
			for {
				fmt.Printf("rewrite import %s to %s in package %s? (y/N)", v.Path, importPath, pkgName(pkg))
				// Raw mode lets us read a single key press without waiting
				// for a newline.
				oldState, err := terminal.MakeRaw(0)
				if err != nil {
					// e.g. stdin is not a terminal; report it instead of
					// crashing the whole program.
					return fmt.Errorf("switching terminal to raw mode: %v", err)
				}
				var buf [1]byte
				_, readErr := os.Stdin.Read(buf[:])
				// Always restore the terminal before acting on the result.
				terminal.Restore(0, oldState)
				fmt.Print("\n")
				if readErr != nil {
					// Without this check a closed stdin leaves buf[0] == 0,
					// which matches no case below and re-prompts forever.
					return fmt.Errorf("reading answer from stdin: %v", readErr)
				}
				switch buf[0] {
				case 'y', 'Y':
					break prompt
				case 'n', 'N', '\r': // \r is what enter sends in raw mode
					continue addRewrites
				case '\x03', '\x01':
					// 0x03 is ctrl+c and 0x01 is ctrl+a.
					// NOTE(review): this was labelled "ctrl+z", but ctrl+z is
					// 0x1a; confirm which key was intended.
					os.Exit(0)
				}
				// Any other key: ask again.
			}

		}
		rewrites[v.Path] = importPath
	}
	if len(rewrites) == 0 {
		return nil
	}
	// TODO go get new imports
	return rewriteImports(fset, pkg, files, rewrites, st, opts)
}
Beispiel #5
0
// rewriteImports applies the rewrites mapping (old import prefix => new
// import prefix) to every file in files, whose keys are used as the paths the
// results are written back to.
//
// For each import whose unquoted path starts with a key of rewrites, the first
// occurrence of that key is replaced by the mapped value. Unless opts.DryRun
// is set, the new import is first fetched through st.DownloadImport; if that
// fails the import is left unchanged and a message is printed to stderr.
// With opts.DryRun or opts.Verbose the planned rewrites are printed; with
// opts.DryRun nothing is modified.
//
// Changed files are re-printed, passed through format.Source and written back
// with their existing file mode. A failure to write one file is reported on
// stderr and does not stop the remaining files; the function always returns
// nil.
func rewriteImports(fset *token.FileSet, pkg *build.Package, files map[string]*ast.File, rewrites map[string]string, st *rewriteState, opts *rewriteOptions) error {
	// Same as go fmt. Loop-invariant, so built once.
	cfg := &printer.Config{
		Tabwidth: 8,
		Mode:     printer.UseSpaces | printer.TabIndent,
	}
	for k, v := range files {
		rewritten := make(map[string]string)
		imports := astutil.Imports(fset, v)
		for _, group := range imports {
			for _, imp := range group {
				unquoted, err := strconv.Unquote(imp.Path.Value)
				if err != nil {
					continue
				}
				for rk, rv := range rewrites {
					// NOTE(review): plain prefix match; a key such as
					// "host/a/b" also matches "host/a/bc". Confirm the keys
					// can never be a partial path element.
					if !strings.HasPrefix(unquoted, rk) {
						continue
					}
					newImport := strings.Replace(unquoted, rk, rv, 1)
					if !opts.DryRun {
						if dlErr := st.DownloadImport(newImport, opts); dlErr != nil {
							fmt.Fprintf(os.Stderr, "couldn't download %s, using original: %s\n", newImport, dlErr)
							continue
						}
					}
					rewritten[unquoted] = newImport
				}
			}
		}
		if len(rewritten) == 0 {
			continue
		}
		if opts.DryRun || opts.Verbose {
			if opts.DryRun {
				fmt.Printf("would rewrite %d imports in %s:\n", len(rewritten), k)
			} else {
				fmt.Printf("rewrite %d imports in %s:\n", len(rewritten), k)
			}
			for ik, iv := range rewritten {
				fmt.Printf("\t%s => %s\n", ik, iv)
			}
			if opts.DryRun {
				continue
			}
		}
		for ik, iv := range rewritten {
			astutil.RewriteImport(fset, v, ik, iv)
		}
		var buf bytes.Buffer
		var data []byte
		var info os.FileInfo // named info so it does not shadow the st parameter
		var err error
		if err = cfg.Fprint(&buf, fset, v); err == nil {
			if data, err = format.Source(buf.Bytes()); err == nil {
				// Stat first so the file keeps its current mode.
				if info, err = os.Stat(k); err == nil {
					err = ioutil.WriteFile(k, data, info.Mode())
				}
			}
		}
		if err != nil {
			fmt.Fprintf(os.Stderr, "error rewriting file %s: %s\n", k, err)
		}
	}
	return nil
}