func main() { wd, err := os.Getwd() if len(os.Args) < 2 { fmt.Println("usage: warp [packages]. warp will OVERWRITE your packages. Use source control.") os.Exit(2) } if err != nil { fatal(err) } var files []string for _, path := range os.Args[1:] { pkg, err := build.Import(path, wd, 0) if err != nil { fatal(err) } for _, file := range pkg.GoFiles { files = append(files, filepath.Join(pkg.Dir, file)) } } for _, fname := range files { if fname == "--" { continue } if !strings.HasSuffix(fname, ".go") || strings.HasSuffix(fname, "_test.go") { continue } fset := token.NewFileSet() f, err := parser.ParseFile(fset, fname, nil, parser.ParseComments) if err != nil { continue } var altered bool for _, d := range f.Decls { fun, ok := d.(*ast.FuncDecl) if !ok { continue } if !ast.IsExported(fun.Name.String()) { continue } // Assume that "io" is imported as "io" if at all. // Spares a long diversion through go.types and/or ugly import processing. for _, arg := range fun.Type.Params.List { sel, ok := arg.Type.(*ast.SelectorExpr) if !ok || sel.Sel.String() != "Reader" || len(arg.Names) != 1 { continue } id, ok := sel.X.(*ast.Ident) if !ok || id.Name != "io" { continue } name := arg.Names[0] assign := warpReader(name, fun.Body.Pos()) body := []ast.Stmt{assign} fun.Body.List = append(body, fun.Body.List...) altered = true } } if altered { astutil.AddImport(fset, f, "github.com/josharian/warp/warped") c, err := os.Create(fname) if err != nil { fatal(err) } printer.Fprint(c, fset, f) } } }
func fixImports(fset *token.FileSet, f *ast.File) (added []string, err error) { // refs are a set of possible package references currently unsatisified by imports. // first key: either base package (e.g. "fmt") or renamed package // second key: referenced package symbol (e.g. "Println") refs := make(map[string]map[string]bool) // decls are the current package imports. key is base package or renamed package. decls := make(map[string]*ast.ImportSpec) // collect potential uses of packages. var visitor visitFn visitor = visitFn(func(node ast.Node) ast.Visitor { if node == nil { return visitor } switch v := node.(type) { case *ast.ImportSpec: if v.Name != nil { decls[v.Name.Name] = v } else { local := importPathToName(strings.Trim(v.Path.Value, `\"`)) decls[local] = v } case *ast.SelectorExpr: xident, ok := v.X.(*ast.Ident) if !ok { break } if xident.Obj != nil { // if the parser can resolve it, it's not a package ref break } pkgName := xident.Name if refs[pkgName] == nil { refs[pkgName] = make(map[string]bool) } if decls[pkgName] == nil { refs[pkgName][v.Sel.Name] = true } } return visitor }) ast.Walk(visitor, f) // Search for imports matching potential package references. searches := 0 type result struct { ipath string name string err error } results := make(chan result) for pkgName, symbols := range refs { if len(symbols) == 0 { continue // skip over packages already imported } go func(pkgName string, symbols map[string]bool) { ipath, rename, err := findImport(pkgName, symbols) r := result{ipath: ipath, err: err} if rename { r.name = pkgName } results <- r }(pkgName, symbols) searches++ } for i := 0; i < searches; i++ { result := <-results if result.err != nil { return nil, result.err } if result.ipath != "" { if result.name != "" { astutil.AddNamedImport(fset, f, result.name, result.ipath) } else { astutil.AddImport(fset, f, result.ipath) } added = append(added, result.ipath) } } // Nil out any unused ImportSpecs, to be removed in following passes unusedImport := map[string]bool{} for pkg, is := range decls { if refs[pkg] == nil && pkg != "_" && pkg != "." { unusedImport[strings.Trim(is.Path.Value, `"`)] = true } } for ipath := range unusedImport { if ipath == "C" { // Don't remove cgo stuff. continue } astutil.DeleteImport(fset, f, ipath) } return added, nil }
// Transform applies the transformation to the specified parsed file, // whose type information is supplied in info, and returns the number // of replacements that were made. // // It mutates the AST in place (the identity of the root node is // unchanged), and may add nodes for which no type information is // available in info. // // Derived from rewriteFile in $GOROOT/src/cmd/gofmt/rewrite.go. // func (tr *Transformer) Transform(info *types.Info, pkg *types.Package, file *ast.File) int { if !tr.seenInfos[info] { tr.seenInfos[info] = true mergeTypeInfo(&tr.info.Info, info) } tr.currentPkg = pkg tr.nsubsts = 0 if tr.verbose { fmt.Fprintf(os.Stderr, "before: %s\n", astString(tr.fset, tr.before)) fmt.Fprintf(os.Stderr, "after: %s\n", astString(tr.fset, tr.after)) } var f func(rv reflect.Value) reflect.Value f = func(rv reflect.Value) reflect.Value { // don't bother if val is invalid to start with if !rv.IsValid() { return reflect.Value{} } rv = apply(f, rv) e := rvToExpr(rv) if e != nil { savedEnv := tr.env tr.env = make(map[string]ast.Expr) // inefficient! Use a slice of k/v pairs if tr.matchExpr(tr.before, e) { if tr.verbose { fmt.Fprintf(os.Stderr, "%s matches %s", astString(tr.fset, tr.before), astString(tr.fset, e)) if len(tr.env) > 0 { fmt.Fprintf(os.Stderr, " with:") for name, ast := range tr.env { fmt.Fprintf(os.Stderr, " %s->%s", name, astString(tr.fset, ast)) } } fmt.Fprintf(os.Stderr, "\n") } tr.nsubsts++ // Clone the replacement tree, performing parameter substitution. // We update all positions to n.Pos() to aid comment placement. rv = tr.subst(tr.env, reflect.ValueOf(tr.after), reflect.ValueOf(e.Pos())) } tr.env = savedEnv } return rv } file2 := apply(f, reflect.ValueOf(file)).Interface().(*ast.File) // By construction, the root node is unchanged. if file != file2 { panic("BUG") } // Add any necessary imports. // TODO(adonovan): remove no-longer needed imports too. if tr.nsubsts > 0 { pkgs := make(map[string]*types.Package) for obj := range tr.importedObjs { pkgs[obj.Pkg().Path()] = obj.Pkg() } for _, imp := range file.Imports { path, _ := strconv.Unquote(imp.Path.Value) delete(pkgs, path) } delete(pkgs, pkg.Path()) // don't import self // NB: AddImport may completely replace the AST! // It thus renders info and tr.info no longer relevant to file. var paths []string for path := range pkgs { paths = append(paths, path) } sort.Strings(paths) for _, path := range paths { astutil.AddImport(tr.fset, file, path) } } tr.currentPkg = nil return tr.nsubsts }