// Process formats and adjusts imports for the provided file. // If opt is nil the defaults are used. func Process(filename string, src []byte, opt *Options) ([]byte, error) { if opt == nil { opt = &Options{Comments: true, TabIndent: true, TabWidth: 8} } fileSet := token.NewFileSet() file, adjust, err := parse(fileSet, filename, src, opt) if err != nil { return nil, err } _, err = fixImports(fileSet, file) if err != nil { return nil, err } sortImports(fileSet, file) imps := astutil.Imports(fileSet, file) var spacesBefore []string // import paths we need spaces before for _, impSection := range imps { // Within each block of contiguous imports, see if any // import lines are in different group numbers. If so, // we'll need to put a space between them so it's // compatible with gofmt. lastGroup := -1 for _, importSpec := range impSection { importPath, _ := strconv.Unquote(importSpec.Path.Value) groupNum := importGroup(importPath) if groupNum != lastGroup && lastGroup != -1 { spacesBefore = append(spacesBefore, importPath) } lastGroup = groupNum } } printerMode := printer.UseSpaces if opt.TabIndent { printerMode |= printer.TabIndent } printConfig := &printer.Config{Mode: printerMode, Tabwidth: opt.TabWidth} var buf bytes.Buffer err = printConfig.Fprint(&buf, fileSet, file) if err != nil { return nil, err } out := buf.Bytes() if adjust != nil { out = adjust(src, out) } if len(spacesBefore) > 0 { out = addImportSpaces(bytes.NewReader(out), spacesBefore) } return out, nil }
// Process formats and adjusts imports for the provided file. // If opt is nil the defaults are used. func Process(filename string, src []byte, opt *Options) ([]byte, error) { if opt == nil { opt = &Options{Comments: true, TabIndent: true, TabWidth: 8} } fileSet := token.NewFileSet() file, adjust, err := parse(fileSet, filename, src, opt) if err != nil { return nil, err } _, err = fixImports(file) if err != nil { return nil, err } sortImports(fileSet, file) imps := astutil.Imports(fileSet, file) var spacesBefore []string // import paths we need spaces before if len(imps) == 1 { // We have just one block of imports. See if any are in different groups numbers. lastGroup := -1 for _, importSpec := range imps[0] { importPath, _ := strconv.Unquote(importSpec.Path.Value) groupNum := importGroup(importPath) if groupNum != lastGroup && lastGroup != -1 { spacesBefore = append(spacesBefore, importPath) } lastGroup = groupNum } } printerMode := printer.UseSpaces if opt.TabIndent { printerMode |= printer.TabIndent } printConfig := &printer.Config{Mode: printerMode, Tabwidth: opt.TabWidth} var buf bytes.Buffer err = printConfig.Fprint(&buf, fileSet, file) if err != nil { return nil, err } out := buf.Bytes() if adjust != nil { out = adjust(src, out) } if len(spacesBefore) > 0 { out = addImportSpaces(bytes.NewReader(out), spacesBefore) } return out, nil }
// If in == nil, the source is the contents of the file with the given filename. func processFile(filename string, in io.Reader, out io.Writer, stdin bool) error { initModesOnce.Do(initModes) if in == nil { f, err := os.Open(filename) if err != nil { return err } defer f.Close() in = f } src, err := ioutil.ReadAll(in) if err != nil { return err } file, adjust, err := parse(fileSet, filename, src, stdin) if err != nil { return err } added, err := fixImports(file) if err != nil { return err } sortImports(fileSet, file) imps := astutil.Imports(fileSet, file) var spacesBefore []string // import paths we need spaces before if len(imps) == 1 && len(added) > 0 { // We have just one block of imports. See if any are in different groups numbers. lastGroup := -1 for _, importSpec := range imps[0] { importPath, _ := strconv.Unquote(importSpec.Path.Value) groupNum := importGroup(importPath) if groupNum != lastGroup && lastGroup != -1 { spacesBefore = append(spacesBefore, importPath) } lastGroup = groupNum } } var buf bytes.Buffer err = (&printer.Config{Mode: printerMode, Tabwidth: *tabWidth}).Fprint(&buf, fileSet, file) if err != nil { return err } res := buf.Bytes() if adjust != nil { res = adjust(src, res) } if len(spacesBefore) > 0 { res = addImportSpaces(bytes.NewReader(res), spacesBefore) } if !bytes.Equal(src, res) { // formatting has changed if *list { fmt.Fprintln(out, filename) } if *write { err = ioutil.WriteFile(filename, res, 0) if err != nil { return err } } if *doDiff { data, err := diff(src, res) if err != nil { return fmt.Errorf("computing diff: %s", err) } fmt.Printf("diff %s gofmt/%s\n", filename, filename) out.Write(data) } } if !*list && !*write && !*doDiff { _, err = out.Write(res) } return err }
func doRewritePackage(pkg *build.Package, st *rewriteState, opts *rewriteOptions) error { libraryMode := opts.LibraryMode(pkg) abs, err := filepath.Abs(pkg.Dir) if err != nil { return err } fset := token.NewFileSet() var names []string names = append(names, pkg.GoFiles...) names = append(names, pkg.CgoFiles...) files, err := parseFiles(fset, abs, names, parser.ParseComments) if err != nil { return err } // First check if we should keep any original imports in the package due to // the use it makes of the imported pkg (type assertions, etc...). disabled := make(map[string]bool) for _, v := range files { imports := astutil.Imports(fset, v) for _, group := range imports { for _, imp := range group { if unquoted, err := strconv.Unquote(imp.Path.Value); err == nil { m := repositoryRe.FindStringSubmatch(unquoted) if len(m) > 0 && shouldKeepOriginalImport(fset, unquoted, imp, v, opts) { disabled[unquoted] = true } } } } } using := make(map[string]bool) for _, v := range files { imports := astutil.Imports(fset, v) for _, group := range imports { for _, imp := range group { if unquoted, err := strconv.Unquote(imp.Path.Value); err == nil { m := repositoryRe.FindStringSubmatch(unquoted) if len(m) > 0 && !disabled[unquoted] { using[m[0]] = true } } } } } // Now check imports we should rewrite if len(using) == 0 { return nil } var repoNames []string for k := range using { repoNames = append(repoNames, k) } if opts.Verbose { fmt.Printf("package %s uses %d 3rd party repositories: %v\n", pkgName(pkg), len(repoNames), repoNames) } repos, err := st.RequestRepos(repoNames) if err != nil { return err } rewrites := make(map[string]string) addRewrites: for _, v := range repos { var importPath string if libraryMode { if v.Version == 0 { if v.AllowsUnpinned { importPath = v.GoPkgsPath } else { if opts.Verbose { fmt.Printf("ignoring package %s, no versions available\n", v.Path) } continue } } else { importPath = v.VersionImportPath() } } else { if opts.PreferRevisions { importPath = v.RevisionImportPath() } else { importPath = v.VersionImportPath() } } if opts.Interactive { prompt: for { fmt.Printf("rewrite import %s to %s in package %s? (y/N)", v.Path, importPath, pkgName(pkg)) oldState, err := terminal.MakeRaw(0) if err != nil { panic(err) } var buf [1]byte os.Stdin.Read(buf[:]) terminal.Restore(0, oldState) fmt.Print("\n") switch buf[0] { case 'y', 'Y': break prompt case 'n', 'N', '\r': // /r is enter continue addRewrites case '\x03', '\x01': // ctrl+c, ctrl+z os.Exit(0) } } } rewrites[v.Path] = importPath } if len(rewrites) == 0 { return nil } // TODO go get new imports return rewriteImports(fset, pkg, files, rewrites, st, opts) }
func rewriteImports(fset *token.FileSet, pkg *build.Package, files map[string]*ast.File, rewrites map[string]string, st *rewriteState, opts *rewriteOptions) error { for k, v := range files { rewritten := make(map[string]string) imports := astutil.Imports(fset, v) for _, group := range imports { for _, imp := range group { if unquoted, err := strconv.Unquote(imp.Path.Value); err == nil { for rk, rv := range rewrites { if !strings.HasPrefix(unquoted, rk) { continue } newImport := strings.Replace(unquoted, rk, rv, 1) if !opts.DryRun { if err := st.DownloadImport(newImport, opts); err != nil { fmt.Fprintf(os.Stderr, "couldn't download %s, using original", newImport) continue } } rewritten[unquoted] = newImport } } } } if len(rewritten) == 0 { continue } if opts.DryRun || opts.Verbose { if opts.DryRun { fmt.Printf("would rewrite %d imports in %s:\n", len(rewritten), k) } else { fmt.Printf("rewrite %d imports in %s:\n", len(rewritten), k) } for ik, iv := range rewritten { fmt.Printf("\t%s => %s\n", ik, iv) } if opts.DryRun { continue } } for ik, iv := range rewritten { astutil.RewriteImport(fset, v, ik, iv) } // Same as go fmt cfg := &printer.Config{ Tabwidth: 8, Mode: printer.UseSpaces | printer.TabIndent, } var buf bytes.Buffer var data []byte var st os.FileInfo var err error if err = cfg.Fprint(&buf, fset, v); err == nil { if data, err = format.Source(buf.Bytes()); err == nil { if st, err = os.Stat(k); err == nil { err = ioutil.WriteFile(k, data, st.Mode()) } } } if err != nil { fmt.Fprintf(os.Stderr, "error rewriting file %s: %s\n", k, err) } } return nil }