Exemple #1
0
// FormatCode rewrites the source file in place in gofmt style, removing
// unused imports and sorting the remaining ones (similar to "goimports -w",
// except that missing imports are not added).
//
// A parse failure is returned together with the content of the file so that
// broken generated code is easy to diagnose. The formatted output is rendered
// in memory before the file is touched, so a formatting failure never leaves
// a truncated file behind, and write/close errors are reported to the caller.
func (f *SourceFile) FormatCode() error {
	// Parse file into AST
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, f.Abs(), nil, parser.ParseComments)
	if err != nil {
		// Best effort: the content only enriches the error message, so a
		// failed read simply yields an empty "Content" section.
		content, _ := ioutil.ReadFile(f.Abs())
		var buf bytes.Buffer
		scanner.PrintError(&buf, err)
		return fmt.Errorf("%s\n========\nContent:\n%s", buf.String(), content)
	}
	// Clean unused imports
	for _, group := range astutil.Imports(fset, file) {
		for _, imp := range group {
			path := strings.Trim(imp.Path.Value, `"`)
			if astutil.UsesImport(file, path) {
				continue
			}
			// DeleteImport only matches unnamed imports, so named ones
			// have to be removed by name.
			if imp.Name != nil {
				astutil.DeleteNamedImport(fset, file, imp.Name.Name, path)
			} else {
				astutil.DeleteImport(fset, file, path)
			}
		}
	}
	ast.SortImports(fset, file)
	// Render the formatted code without unused imports before touching the file.
	var out bytes.Buffer
	if err := format.Node(&out, fset, file); err != nil {
		return err
	}
	// WriteFile creates or truncates the file and reports write and close errors.
	return ioutil.WriteFile(f.Abs(), out.Bytes(), os.ModePerm)
}
Exemple #2
0
// init registers the "/fmt" action, which sorts the imports of the submitted
// source and returns it pretty-printed.
func init() {
	act(Action{
		Path: "/fmt",
		Doc: `
formats the source like gofmt does
@data: {"fn": "...", "src": "..."}
@resp: "formatted source"
`,
		Func: func(r Request) (data, error) {
			// Defaults used when the request does not override them.
			args := AcFmtArgs{TabIndent: true, TabWidth: 8}
			if err := r.Decode(&args); err != nil {
				return "", err
			}

			fset, af, err := parseAstFile(args.Fn, args.Src, parser.ParseComments)
			if err != nil {
				return "", err
			}

			ast.SortImports(fset, af)
			src, err := printSrc(fset, af, args.TabIndent, args.TabWidth)
			return src, err
		},
	})
}
Exemple #3
0
// FormatCode rewrites the source file in place in gofmt style, removing
// unused imports and sorting the remaining ones (similar to "goimports -w",
// except that missing imports are not added).
//
// It is a no-op when NoFormat is set. Parse, open, formatting and close
// errors are returned to the caller.
func (f *SourceFile) FormatCode() error {
	if NoFormat {
		return nil
	}
	// Parse file into AST
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, f.Abs(), nil, parser.ParseComments)
	if err != nil {
		return err
	}
	// Clean unused imports
	for _, group := range astutil.Imports(fset, file) {
		for _, imp := range group {
			path := strings.Trim(imp.Path.Value, `"`)
			if astutil.UsesImport(file, path) {
				continue
			}
			// DeleteImport only matches unnamed imports, so an unused
			// named import has to be removed by name.
			if imp.Name != nil {
				astutil.DeleteNamedImport(fset, file, imp.Name.Name, path)
			} else {
				astutil.DeleteImport(fset, file, path)
			}
		}
	}
	ast.SortImports(fset, file)
	// Open file to be written
	w, err := os.OpenFile(f.Abs(), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm)
	if err != nil {
		return err
	}
	// Write formatted code without unused imports
	if err := format.Node(w, fset, file); err != nil {
		w.Close() // the formatting error is the one worth reporting
		return err
	}
	// A failed close can mean the data never reached the disk.
	return w.Close()
}
Exemple #4
0
// rewriteImportsInFile rewrites the import paths of the Go file fi using rw
// (inspired by godeps rewrite; used to substitute gx vendored names).
//
// rw receives each unquoted import path and returns its replacement. When no
// path changes the file is left untouched. Otherwise the file is re-printed,
// its imports are sorted, and the result is written to fi+".temp" and renamed
// over fi. The temporary file is closed and removed if writing it fails.
func rewriteImportsInFile(fi string, rw func(string) string) error {
	cfg := &printer.Config{Mode: printer.UseSpaces | printer.TabIndent, Tabwidth: 8}
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, fi, nil, parser.ParseComments)
	if err != nil {
		return err
	}

	var changed bool
	for _, imp := range file.Imports {
		p, err := strconv.Unquote(imp.Path.Value)
		if err != nil {
			return err
		}

		if np := rw(p); np != p {
			changed = true
			imp.Path.Value = strconv.Quote(np)
		}
	}

	if !changed {
		return nil
	}

	buf := bufpool.Get().(*bytes.Buffer)
	buf.Reset()
	// Hand the buffer back on every path, including the error returns.
	defer func() {
		buf.Reset()
		bufpool.Put(buf)
	}()

	if err = cfg.Fprint(buf, fset, file); err != nil {
		return err
	}

	// Re-parse the printed source so that SortImports works on positions
	// that match the rewritten import paths.
	fset = token.NewFileSet()
	file, err = parser.ParseFile(fset, fi, buf, parser.ParseComments)
	if err != nil {
		return err
	}

	ast.SortImports(fset, file)

	wpath := fi + ".temp"
	w, err := os.Create(wpath)
	if err != nil {
		return err
	}

	if err = cfg.Fprint(w, fset, file); err != nil {
		w.Close()
		os.Remove(wpath)
		return err
	}

	if err = w.Close(); err != nil {
		os.Remove(wpath)
		return err
	}

	return os.Rename(wpath, fi)
}
Exemple #5
0
// gofmt reformats the Go source held in src in place: the buffer content is
// parsed, its imports are sorted, and the pretty-printed result replaces the
// previous content of src.
//
// The buffer is reset before printing, so its original content is gone even
// if printing fails.
func gofmt(fset *token.FileSet, filename string, src *bytes.Buffer) error {
	file, _, err := parse(fset, filename, src.Bytes(), false)
	if err != nil {
		return err
	}
	ast.SortImports(fset, file)

	// src doubles as the output buffer from here on.
	src.Reset()
	cfg := printer.Config{Mode: printerMode, Tabwidth: *tabWidth}
	return cfg.Fprint(src, fset, file)
}
Exemple #6
0
// Call parses m.Src, sorts its imports and stores the printed source under
// the "src" key of the returned map. The second result is the error rendered
// by errStr; the map is returned even when parsing or printing fails.
func (m *mFmt) Call() (interface{}, string) {
	out := M{}

	fset, af, err := parseAstFile(m.Fn, m.Src, parser.ParseComments)
	if err != nil {
		return out, errStr(err)
	}

	ast.SortImports(fset, af)
	src, err := printSrc(fset, af, m.TabIndent, m.TabWidth)
	out["src"] = src
	return out, errStr(err)
}
Exemple #7
0
// gofmtFile sorts the imports of f and returns the file printed with
// printConfig. Both steps use the package-level fset. Note that f is
// modified by the import sorting.
func gofmtFile(f *ast.File) ([]byte, error) {
	ast.SortImports(fset, f)

	out := new(bytes.Buffer)
	if err := printConfig.Fprint(out, fset, f); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
Exemple #8
0
// GenerateMockSource writes to w the source code of a package named pkg that
// contains mock implementations of the given interfaces.
//
// pkg must be non-empty, and interfaces must be a non-empty list of interface
// types. The code produced by the template is parsed, its imports are sorted
// the way gofmt does, and it is pretty-printed with gofmt's default options.
func GenerateMockSource(w io.Writer, pkg string, interfaces []reflect.Type) error {
	// Sanity-check arguments.
	switch {
	case pkg == "":
		return errors.New("Package name must be non-empty.")
	case len(interfaces) == 0:
		return errors.New("List of interfaces must be non-empty.")
	}

	// Every type handed in has to be an interface.
	for _, typ := range interfaces {
		if typ.Kind() != reflect.Interface {
			return errors.New("Invalid type: " + typ.String())
		}
	}

	// Run the template, collecting the raw output in memory.
	arg := tmplArg{
		Pkg:        pkg,
		Imports:    getImports(interfaces),
		Interfaces: interfaces,
	}

	var raw bytes.Buffer
	if err := tmpl.Execute(&raw, arg); err != nil {
		return err
	}

	// Parse the generated code so that it can be normalized.
	fset := token.NewFileSet()
	astFile, err := parser.ParseFile(fset, pkg+".go", &raw, parser.ParseComments)
	if err != nil {
		return errors.New("Error parsing generated code: " + err.Error())
	}
	ast.SortImports(fset, astFile)

	// Same printer options as gofmt's defaults.
	cfg := printer.Config{
		Mode:     printer.UseSpaces | printer.TabIndent,
		Tabwidth: 8,
	}
	if err := cfg.Fprint(w, fset, astFile); err != nil {
		return errors.New("Error pretty printing: " + err.Error())
	}
	return nil
}
Exemple #9
0
// gofmt takes a complete Go program, formats it using the standard Go
// formatting rules, and returns it or an error.
//
// Imports are sorted and the source is printed with the same printer
// configuration gofmt uses (tabs for indentation, spaces for alignment, tab
// width 8). A syntax error in body is returned as-is.
func gofmt(body string) (string, error) {
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, "prog.go", body, parser.ParseComments)
	if err != nil {
		return "", err
	}
	ast.SortImports(fset, f)

	// printer.Fprint would use the printer's default mode, which aligns
	// with tabs and therefore does not produce gofmt's output.
	config := &printer.Config{Mode: printer.UseSpaces | printer.TabIndent, Tabwidth: 8}
	var buf bytes.Buffer
	if err := config.Fprint(&buf, fset, f); err != nil {
		return "", err
	}
	return buf.String(), nil
}
Exemple #10
0
// gofmt parses body as a complete Go source file, sorts its imports and
// returns it printed with gofmt's printer settings (tabs for indentation,
// spaces for alignment, tab width 8). Parse and print errors are returned
// together with an empty string.
func gofmt(body string) (string, error) {
	fileSet := token.NewFileSet()
	parsed, err := parser.ParseFile(fileSet, "prog.go", body, parser.ParseComments)
	if err != nil {
		return "", err
	}
	ast.SortImports(fileSet, parsed)

	cfg := printer.Config{Mode: printer.UseSpaces | printer.TabIndent, Tabwidth: 8}
	out := new(bytes.Buffer)
	if err := cfg.Fprint(out, fileSet, parsed); err != nil {
		return "", err
	}
	return out.String(), nil
}
Exemple #11
0
// Source formats src in canonical gofmt style and returns the result
// or an (I/O or syntax) error. src is expected to be a syntactically
// correct Go source file, or a list of Go declarations or statements.
//
// For a partial source file the result keeps the leading and trailing
// space of src and is indented like the first line of src that contains
// code. Imports are only sorted for complete source files.
//
func Source(src []byte) ([]byte, error) {
	fset := token.NewFileSet()
	file, sourceAdj, indentAdj, err := parse(fset, "", src, true)
	if err != nil {
		return nil, err
	}

	// A nil sourceAdj is what identifies a complete source file here.
	// TODO(gri) consider doing this always.
	if complete := sourceAdj == nil; complete {
		ast.SortImports(fset, file)
	}

	out, err := format(fset, file, sourceAdj, indentAdj, src, config)
	return out, err
}
Exemple #12
0
// rewriteGoFile rewrites import statements in the named file
// according to the rules for func qualify.
//
// Each import path is passed through unqualify and then qualify(…, qual,
// paths). When no path changes the file is left untouched. Otherwise the file
// is re-printed with sorted imports into name+".temp", which then replaces
// the original. The temporary file is closed and removed if any later step
// fails.
func rewriteGoFile(name, qual string, paths []string) error {
	debugln("rewriteGoFile", name, ",", qual, ",", paths)
	printerConfig := &printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, name, nil, parser.ParseComments)
	if err != nil {
		return err
	}

	var changed bool
	for _, s := range f.Imports {
		ipath, err := strconv.Unquote(s.Path.Value)
		if err != nil {
			return err // can't happen
		}
		q := qualify(unqualify(ipath), qual, paths)
		if q != ipath {
			s.Path.Value = strconv.Quote(q)
			changed = true
		}
	}
	if !changed {
		return nil
	}

	var buffer bytes.Buffer
	if err = printerConfig.Fprint(&buffer, fset, f); err != nil {
		return err
	}
	// Re-parse the printed source so that SortImports sees positions that
	// match the rewritten paths. A failure here must stop the rewrite:
	// continuing would hand a nil file to SortImports.
	fset = token.NewFileSet()
	f, err = parser.ParseFile(fset, name, &buffer, parser.ParseComments)
	if err != nil {
		return err
	}
	ast.SortImports(fset, f)

	tpath := name + ".temp"
	t, err := os.Create(tpath)
	if err != nil {
		return err
	}
	if err = printerConfig.Fprint(t, fset, f); err != nil {
		t.Close()
		os.Remove(tpath)
		return err
	}
	if err = t.Close(); err != nil {
		os.Remove(tpath)
		return err
	}
	// This is required before the rename on windows.
	if err = os.Remove(name); err != nil {
		os.Remove(tpath)
		return err
	}
	return os.Rename(tpath, name)
}
Exemple #13
0
// write sorts the imports of f and prints it to filePath, creating the
// parent directory first when filePath has one.
//
// The file is printed with the package-level printerMode and tabWidth. The
// first error encountered is returned; this includes a failed close, since
// that can mean the printed source never reached the disk.
func write(filePath string, fset *token.FileSet, f *ast.File) error {
	if parentdir := path.Dir(filePath); parentdir != "." {
		if err := os.MkdirAll(parentdir, os.ModePerm); err != nil {
			return err
		}
	}

	file, err := os.Create(filePath)
	if err != nil {
		return err
	}

	ast.SortImports(fset, f)
	err = (&printer.Config{Mode: printerMode, Tabwidth: tabWidth}).Fprint(file, fset, f)
	// Always close, but let a printing error take precedence.
	if cerr := file.Close(); err == nil {
		err = cerr
	}
	return err
}
Exemple #14
0
// scanPackages parses filename and collects, through appendPkg, the (still
// quoted) import paths for which no entry exists under $GOROOT/src/pkg.
// It returns nil when the file cannot be parsed.
//
// NOTE(review): GOROOT is read from the environment and the "src/pkg" layout
// is assumed; verify that this matches the Go installations in use.
func scanPackages(filename string) (ret []string) {
	fset := token.NewFileSet()
	parsed, parseErr := parser.ParseFile(fset, filename, nil, 0)
	if parseErr != nil {
		return nil
	}
	ast.SortImports(fset, parsed)

	pkgRoot := os.Getenv("GOROOT")
	for _, spec := range parsed.Imports {
		dir := filepath.Join(pkgRoot, "src", "pkg", unquote(spec.Path.Value))
		if _, statErr := os.Stat(dir); statErr != nil {
			ret = appendPkg(ret, spec.Path.Value)
		}
	}
	return ret
}
Exemple #15
0
// goProcessFile reads the Go file filename, sorts its imports, prints it with
// goPrinterMode and writes the result next to it with the ".go" suffix
// replaced by ".igo". When *DestDir is set, the output path is placed under
// that directory and createDir is called for it first.
//
// The in and out parameters are not used by this function.
func goProcessFile(filename string, in io.Reader, out io.Writer) error {
	target := strings.TrimSuffix(filename, ".go") + ".igo"

	input, err := os.Open(filename)
	if err != nil {
		return err
	}
	defer input.Close()

	src, err := ioutil.ReadAll(input)
	if err != nil {
		return err
	}

	file, adjust, err := goParse(goFileSet, filename, src)
	if err != nil {
		return err
	}
	ast.SortImports(goFileSet, file)

	cfg := printer.Config{Mode: goPrinterMode, Tabwidth: *tabWidth}
	var printed bytes.Buffer
	if err := cfg.Fprint(&printed, goFileSet, file); err != nil {
		return err
	}

	res := printed.Bytes()
	if adjust != nil {
		res = adjust(src, res)
	}

	if *DestDir != "" {
		target = filepath.Join(*DestDir, target)
		createDir(target)
	}

	return ioutil.WriteFile(target, res, 0644)
}
Exemple #16
0
// Clean parses src as a complete Go source file, removes its unused imports,
// sorts the remaining ones and writes the gofmt-formatted result to w.
//
// src is handed to parser.ParseFile as-is and can be an io.Reader, a string
// or a []byte. Parse and formatting errors are returned.
func Clean(w io.Writer, src interface{}) error {
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "clean.go", src, parser.ParseComments)
	if err != nil {
		return err
	}
	// Clean unused imports
	for _, group := range astutil.Imports(fset, file) {
		for _, imp := range group {
			path := strings.Trim(imp.Path.Value, `"`)
			if astutil.UsesImport(file, path) {
				continue
			}
			// DeleteImport only matches unnamed imports, so an unused
			// named import has to be removed by name.
			if imp.Name != nil {
				astutil.DeleteNamedImport(fset, file, imp.Name.Name, path)
			} else {
				astutil.DeleteImport(fset, file, path)
			}
		}
	}
	ast.SortImports(fset, file)
	// Write formatted code without unused imports
	return format.Node(w, fset, file)
}
Exemple #17
0
// GoFmt formats Go source the way gofmt does and saves it as filename.
//
// The source is read from in; if in is nil, it is read from the file named
// filename instead. Imports are sorted and the code is printed with gofmt's
// printer settings. The file is only written when formatting changed the
// source. A file that does not exist yet is created with mode 0644 (before
// umask) so that it is readable afterwards; the mode of an existing file is
// left alone.
//
// Read, syntax, printing and write errors are returned.
func GoFmt(filename string, in io.Reader) error {
	if in == nil {
		f, err := os.Open(filename)
		if err != nil {
			return err
		}
		defer f.Close()
		in = f
	}

	src, err := ioutil.ReadAll(in)
	if err != nil {
		return err
	}
	fileSet := token.NewFileSet()
	file, err := parser.ParseFile(fileSet, filename, src, parser.ParseComments)
	if err != nil {
		return err
	}

	ast.SortImports(fileSet, file)

	cfg := printer.Config{Mode: printer.UseSpaces | printer.TabIndent, Tabwidth: 8}
	var buf bytes.Buffer
	if err := cfg.Fprint(&buf, fileSet, file); err != nil {
		return err
	}

	res := buf.Bytes()
	if bytes.Equal(src, res) {
		return nil
	}
	// The permission bits only matter when WriteFile has to create the
	// file; passing 0 here would create a file nobody can open.
	return ioutil.WriteFile(filename, res, 0644)
}
Exemple #18
0
// Node formats node in canonical gofmt style and writes the result to dst.
//
// The node type must be *ast.File, *printer.CommentedNode, []ast.Decl,
// []ast.Stmt, or assignment-compatible to ast.Expr, ast.Decl, ast.Spec,
// or ast.Stmt. Node does not modify node. Imports are sorted only when the
// node represents a complete source file, i.e. it is an *ast.File or a
// *printer.CommentedNode wrapping an *ast.File.
//
// The function may return early (before the entire result is written)
// and return a formatting error, for instance due to an incorrect AST.
//
func Node(dst io.Writer, fset *token.FileSet, node interface{}) error {
	// Find out whether node is a complete source file (file != nil) and
	// remember the comment wrapper, if there is one.
	var (
		file  *ast.File
		cnode *printer.CommentedNode
	)
	if f, isFile := node.(*ast.File); isFile {
		file = f
	} else if c, isCommented := node.(*printer.CommentedNode); isCommented {
		if f, wrapsFile := c.Node.(*ast.File); wrapsFile {
			file, cnode = f, c
		}
	}

	// Nothing to sort: print the node exactly as it was passed in.
	if file == nil || !hasUnsortedImports(file) {
		return config.Fprint(dst, fset, node)
	}

	// ast.SortImports is destructive, so sort a copy of the AST obtained
	// by printing the file and parsing it again.
	// TODO(gri) Do this more efficiently.
	var src bytes.Buffer
	if err := config.Fprint(&src, fset, file); err != nil {
		return err
	}
	sorted, err := parser.ParseFile(fset, "", src.Bytes(), parserMode)
	if err != nil {
		// We should never get here. If we do, provide good diagnostic.
		return fmt.Errorf("format.Node internal error (%s)", err)
	}
	ast.SortImports(fset, sorted)

	// Print the copy, re-attaching the caller's comments if they were given.
	if cnode != nil {
		return config.Fprint(dst, fset, &printer.CommentedNode{Node: sorted, Comments: cnode.Comments})
	}
	return config.Fprint(dst, fset, sorted)
}
Exemple #19
0
// addImportsToMock adds to mockAst those of the given import specs that the
// mock actually uses and then sorts the imports of mockAst. It does nothing
// when none of the imports is used.
func addImportsToMock(mockAst *ast.File, fset *token.FileSet, imports []*ast.ImportSpec) {
	// Collect what the mock AST refers to.
	finder := newFindUsedImports()
	ast.Walk(finder, mockAst)

	// Keep only the input imports that are used by the mock.
	var used []ast.Spec
	for _, spec := range imports {
		if finder.isUsed(spec) {
			used = append(used, spec)
		}
	}
	if len(used) == 0 {
		return
	}

	// Insert them into the mock AST and tidy up the import order.
	ast.Walk(&addImports{used}, mockAst)
	ast.SortImports(fset, mockAst)
}
Exemple #20
0
// RemoveImport removes the import of path from source and returns the
// resulting source code.
//
// Only the part of source returned as header by the header helper is parsed
// and re-printed; the remaining body is appended unchanged after a newline.
// RemoveImport panics if the header is empty, cannot be parsed, or cannot be
// printed.
func RemoveImport(source, path string) string {
	head, rest := header(source)
	if head == "" {
		panic("parse failure")
	}

	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "", []byte(head), 0)
	if err != nil {
		panic(err)
	}

	astutil.DeleteImport(fset, file, path)
	ast.SortImports(fset, file)

	var out bytes.Buffer
	if err := printer.Fprint(&out, fset, file); err != nil {
		panic(err)
	}
	return out.String() + "\n" + rest
}
Exemple #21
0
// WriteFile parses fileOut as Go source, sorts its imports and writes the
// pretty-printed result to filePath, creating the parent directory first
// when filePath has one.
//
// WriteFile panics on any failure: a syntax error in fileOut (the offending
// source is printed first), a directory or file that cannot be created, and
// a printing or close error, which would otherwise leave a truncated file
// behind unnoticed.
func WriteFile(filePath, fileOut string) {
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, "", fileOut, parser.ParseComments)
	if err != nil {
		fmt.Println("Failed to parse:\n", fileOut)
		panic(err)
	}

	ast.SortImports(fset, f)

	//create parentdir if it doesn't exist
	if parentdir := path.Dir(filePath); parentdir != "." {
		if err := os.MkdirAll(parentdir, os.ModePerm); err != nil {
			panic(err)
		}
	}

	file, err := os.Create(filePath)
	if err != nil {
		panic(err)
	}

	err = (&printer.Config{Mode: printerMode, Tabwidth: tabWidth}).Fprint(file, fset, f)
	// Always close, but let a printing error take precedence.
	if cerr := file.Close(); err == nil {
		err = cerr
	}
	if err != nil {
		panic(err)
	}
}
Exemple #22
0
// rewriteGodepfilesHandler is a filepath.WalkFunc that rewrites Godep import
// paths in Go source files.
//
// Directories named "testdata" or "vendor" are skipped entirely, and files
// without a ".go" extension are ignored. Every import path of a Go file is
// passed through rewriteGodepImport; when at least one path changes, the
// file is re-printed with sorted imports into path+".temp", which then
// replaces the original.
//
// An error reported by the walk itself is returned unchanged, because info
// may be nil in that case. The temporary file is closed and removed if any
// later step fails.
func rewriteGodepfilesHandler(path string, info os.FileInfo, err error) error {
	if err != nil {
		return err
	}

	if info.IsDir() {
		if name := info.Name(); name == "testdata" || name == "vendor" {
			return filepath.SkipDir
		}
		return nil
	}

	if e := filepath.Ext(path); e != ".go" {
		return nil
	}

	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
	if err != nil {
		return err
	}

	var changed bool
	for _, s := range f.Imports {
		n, err := strconv.Unquote(s.Path.Value)
		if err != nil {
			return err
		}
		// Compare against the original import path, not the file name;
		// otherwise every file with imports would count as changed.
		q := rewriteGodepImport(n)
		if q != n {
			s.Path.Value = strconv.Quote(q)
			changed = true
		}
	}
	if !changed {
		return nil
	}

	printerConfig := &printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}
	var buffer bytes.Buffer
	if err = printerConfig.Fprint(&buffer, fset, f); err != nil {
		return err
	}
	// Re-parse the printed source so that SortImports sees positions that
	// match the rewritten paths. A failure here must stop the rewrite:
	// continuing would hand a nil file to SortImports.
	fset = token.NewFileSet()
	f, err = parser.ParseFile(fset, path, &buffer, parser.ParseComments)
	if err != nil {
		return err
	}
	ast.SortImports(fset, f)

	tpath := path + ".temp"
	t, err := os.Create(tpath)
	if err != nil {
		return err
	}
	if err = printerConfig.Fprint(t, fset, f); err != nil {
		t.Close()
		os.Remove(tpath)
		return err
	}
	if err = t.Close(); err != nil {
		os.Remove(tpath)
		return err
	}

	msg.Debug("Rewriting Godep imports for %s", path)

	// This is required before the rename on windows.
	if err = os.Remove(path); err != nil {
		os.Remove(tpath)
		return err
	}
	return os.Rename(tpath, path)
}
Exemple #23
0
// Source formats src in canonical gofmt style and returns the result
// or an (I/O or syntax) error. src is expected to be a syntactically
// correct Go source file, or a list of Go declarations or statements.
//
// For a partial source file the result keeps the leading and trailing
// space of src and is indented like the first line of src that contains
// code. Imports are only sorted for complete source files.
//
func Source(src []byte) ([]byte, error) {
	fset := token.NewFileSet()
	node, err := parse(fset, src)
	if err != nil {
		return nil, err
	}

	var out bytes.Buffer

	// Complete source file: sort the imports and print it as is.
	if file, ok := node.(*ast.File); ok {
		ast.SortImports(fset, file)
		if err := config.Fprint(&out, fset, file); err != nil {
			return nil, err
		}
		return out.Bytes(), nil
	}

	// Partial source file.
	// Scan the leading space: codeStart ends up at the first non-space byte
	// and lineStart at the beginning of the line containing it.
	lineStart, codeStart := 0, 0
	for ; codeStart < len(src) && isSpace(src[codeStart]); codeStart++ {
		if src[codeStart] == '\n' {
			lineStart = codeStart + 1
		}
	}
	out.Write(src[:lineStart])

	// Derive the indentation of the first code line. Spaces are ignored
	// unless there are no tabs, in which case they count as one tab.
	tabs, sawSpace := 0, false
	for _, c := range src[lineStart:codeStart] {
		if c == ' ' {
			sawSpace = true
		} else if c == '\t' {
			tabs++
		}
	}
	if tabs == 0 && sawSpace {
		tabs = 1
	}

	// Format the source with that indentation.
	cfg := config
	cfg.Indent = tabs
	if err := cfg.Fprint(&out, fset, node); err != nil {
		return nil, err
	}

	// Re-attach the trailing space of the original source.
	end := len(src)
	for end > 0 && isSpace(src[end-1]) {
		end--
	}
	out.Write(src[end:])

	return out.Bytes(), nil
}
Exemple #24
0
// processFile formats one source and reports the result as selected by the
// list, write and doDiff settings.
//
// If in == nil, the source is the contents of the file with the given filename.
// When the formatted source differs from the input, the file name is printed
// to out (list), the file is overwritten (write) and/or a diff is emitted
// (doDiff). When none of the three settings is enabled, the formatted source
// is written to out.
func processFile(filename string, in io.Reader, out io.Writer, stdin bool) error {
	if in == nil {
		fh, openErr := os.Open(filename)
		if openErr != nil {
			return openErr
		}
		defer fh.Close()
		in = fh
	}

	src, err := ioutil.ReadAll(in)
	if err != nil {
		return err
	}

	file, sourceAdj, indentAdj, err := parse(fileSet, filename, src, stdin)
	if err != nil {
		return err
	}

	switch {
	case rewrite == nil:
		// no rewrite rule configured
	case sourceAdj == nil:
		file = rewrite(file)
	default:
		fmt.Fprintf(os.Stderr, "warning: rewrite ignored for incomplete programs\n")
	}

	ast.SortImports(fileSet, file)

	if *simplifyAST {
		simplify(file)
	}

	cfg := printer.Config{Mode: printerMode, Tabwidth: tabWidth}
	res, err := format(fileSet, file, sourceAdj, indentAdj, src, cfg)
	if err != nil {
		return err
	}

	// Everything in this block only applies when formatting changed the source.
	if !bytes.Equal(src, res) {
		if *list {
			fmt.Fprintln(out, filename)
		}
		if *write {
			if werr := ioutil.WriteFile(filename, res, 0644); werr != nil {
				return werr
			}
		}
		if *doDiff {
			patch, derr := diff(src, res)
			if derr != nil {
				return fmt.Errorf("computing diff: %s", derr)
			}
			fmt.Printf("diff %s gofmt/%s\n", filename, filename)
			out.Write(patch)
		}
	}

	if *list || *write || *doDiff {
		return nil
	}
	_, err = out.Write(res)
	return err
}
Exemple #25
0
// imp applies the import additions and removals requested in toggle to af
// and returns af, which is modified in place, with its imports sorted.
//
// Each toggle entry identifies an import by path and name; entries with Add
// set are added unless an identical import is already present, all others
// are removed. An import of "C" is never removed, and new imports are not
// added to a declaration that contains it. Declarations left without specs
// are dropped from the file.
func imp(fset *token.FileSet, af *ast.File, toggle []ImportDeclArg) *ast.File {
	// Split the requests into imports to add and imports to delete.
	add := map[ImportDecl]bool{}
	del := map[ImportDecl]bool{}
	for _, sda := range toggle {
		sd := ImportDecl{
			Path: sda.Path,
			Name: sda.Name,
		}
		if sda.Add {
			add[sd] = true
		} else {
			del[sd] = true
		}
	}

	// firstDecl is the declaration new imports get appended to; imports
	// records the imports that remain in the file after deletions.
	// NOTE(review): gdecl.Tok is not checked below, so firstDecl can end up
	// being a non-import declaration when the file has no usable import
	// declaration — confirm that this is intended.
	var firstDecl *ast.GenDecl
	imports := map[ImportDecl]bool{}
	for _, decl := range af.Decls {
		if gdecl, ok := decl.(*ast.GenDecl); ok && len(gdecl.Specs) > 0 {
			hasC := false
			// sj is the write index used to compact gdecl.Specs in place
			// while deleted specs are skipped.
			sj := 0
			for _, spec := range gdecl.Specs {
				if ispec, ok := spec.(*ast.ImportSpec); ok {
					sd := ImportDecl{
						Path: unquote(ispec.Path.Value),
					}
					if ispec.Name != nil {
						sd.Name = ispec.Name.String()
					}

					if sd.Path == "C" {
						hasC = true
					} else if del[sd] {
						// Drop this spec. The previously kept spec, if it is
						// an import, has its end position moved to the start
						// of the dropped one (presumably to keep positions
						// consistent for printing — TODO confirm).
						if sj > 0 {
							if lspec, ok := gdecl.Specs[sj-1].(*ast.ImportSpec); ok {
								lspec.EndPos = ispec.Pos()
							}
						}
						continue
					} else {
						imports[sd] = true
					}
				}

				// Keep the spec (non-import specs are always kept).
				gdecl.Specs[sj] = spec
				sj += 1
			}
			gdecl.Specs = gdecl.Specs[:sj]

			// The first declaration that had specs and no "C" import
			// receives the additions.
			if !hasC && firstDecl == nil {
				firstDecl = gdecl
			}
		}
	}

	if len(add) > 0 {
		if firstDecl == nil {
			// No suitable declaration exists: create a new import
			// declaration. Lparen is set to a non-zero position, presumably
			// to request the parenthesized form — TODO confirm.
			// NOTE(review): the new declaration is appended after all
			// existing declarations — verify the printed result when the
			// file already contains other declarations.
			firstDecl = &ast.GenDecl{
				Tok:    token.IMPORT,
				Lparen: 1,
			}
			af.Decls = append(af.Decls, firstDecl)
		} else if firstDecl.Lparen == token.NoPos {
			firstDecl.Lparen = 1
		}

		// Append every requested import that is not already present. Map
		// iteration order is unspecified, so the append order varies.
		for sd, _ := range add {
			if !imports[sd] {
				ispec := &ast.ImportSpec{
					Path: &ast.BasicLit{
						Value: quote(sd.Path),
						Kind:  token.STRING,
					},
				}
				if sd.Name != "" {
					ispec.Name = &ast.Ident{
						Name: sd.Name,
					}
				}
				firstDecl.Specs = append(firstDecl.Specs, ispec)
				imports[sd] = true
			}
		}
	}

	// Compact af.Decls in place, dropping general declarations that have no
	// specs left.
	dj := 0
	for _, decl := range af.Decls {
		if gdecl, ok := decl.(*ast.GenDecl); ok {
			if len(gdecl.Specs) == 0 {
				continue
			}
		}
		af.Decls[dj] = decl
		dj += 1
	}
	af.Decls = af.Decls[:dj]

	ast.SortImports(fset, af)
	return af
}
Exemple #26
0
// GenerateMockSource writes to w source code containing mock implementations
// of the given interfaces, suitable for inclusion in a package with the full
// package path outputPkgPath.
//
// outputPkgPath must be non-empty, and interfaces must be a non-empty list of
// interface types that all come from the same package. The code produced by
// the template is parsed, its imports are sorted the way gofmt does, and it
// is pretty-printed with gofmt's default options.
func GenerateMockSource(
	w io.Writer,
	outputPkgPath string,
	interfaces []reflect.Type) (err error) {
	// Sanity-check arguments.
	switch {
	case outputPkgPath == "":
		return errors.New("Package path must be non-empty.")
	case len(interfaces) == 0:
		return errors.New("List of interfaces must be non-empty.")
	}

	// Every type handed in has to be an interface.
	for _, typ := range interfaces {
		if typ.Kind() != reflect.Interface {
			return errors.New("Invalid type: " + typ.String())
		}
	}

	// All interfaces have to share the package of the first one.
	interfacePkgPath := interfaces[0].PkgPath()
	for _, typ := range interfaces {
		if got := typ.PkgPath(); got != interfacePkgPath {
			return fmt.Errorf(
				"Package path mismatch: %q vs. %q",
				interfacePkgPath,
				got)
		}
	}

	// Set up an appropriate template arg.
	arg := tmplArg{
		Interfaces:       interfaces,
		InterfacePkgPath: interfacePkgPath,
		OutputPkgPath:    outputPkgPath,
		Imports:          getImports(interfaces, outputPkgPath),
	}

	// Configure and parse the template.
	helpers := template.FuncMap{
		"pathBase":           path.Base,
		"getMethods":         getMethods,
		"getInputs":          getInputs,
		"getOutputs":         getOutputs,
		"getInputTypeString": arg.getInputTypeString,
		"getTypeString":      arg.getTypeString,
	}
	tmpl := template.New("code")
	tmpl.Funcs(helpers)
	if _, err = tmpl.Parse(gTmplStr); err != nil {
		return fmt.Errorf("Parse: %v", err)
	}

	// Execute the template, collecting the raw output into a buffer.
	raw := new(bytes.Buffer)
	if err = tmpl.Execute(raw, arg); err != nil {
		return err
	}

	// Parse the output so that it can be normalized.
	fset := token.NewFileSet()
	astFile, err := parser.ParseFile(
		fset,
		path.Base(outputPkgPath+".go"),
		raw,
		parser.ParseComments)
	if err != nil {
		return fmt.Errorf("parser.ParseFile: %v", err)
	}

	// Sort the import lines in the AST in the same way that gofmt does.
	ast.SortImports(fset, astFile)

	// Same printer options as gofmt's defaults.
	cfg := printer.Config{
		Mode:     printer.UseSpaces | printer.TabIndent,
		Tabwidth: 8,
	}
	if err = cfg.Fprint(w, fset, astFile); err != nil {
		return errors.New("Error pretty printing: " + err.Error())
	}
	return nil
}
Exemple #27
0
// processFile formats one source and reports the result as selected by the
// list, write and doDiff settings.
//
// If in == nil, the source is the contents of the file with the given filename.
// When the formatted source differs from the input, the file name is printed
// to out (list), the file is overwritten (write) and/or a diff is emitted
// (doDiff). When none of the three settings is enabled, the formatted source
// is written to out.
func processFile(filename string, in io.Reader, out io.Writer, stdin bool) error {
	if in == nil {
		f, err := os.Open(filename)
		if err != nil {
			return err
		}
		defer f.Close()
		in = f
	}

	src, err := ioutil.ReadAll(in)
	if err != nil {
		return err
	}

	file, adjust, err := parse(filename, src, stdin)
	if err != nil {
		return err
	}

	if rewrite != nil {
		if adjust == nil {
			file = rewrite(file)
		} else {
			fmt.Fprintf(os.Stderr, "warning: rewrite ignored for incomplete programs\n")
		}
	}

	ast.SortImports(fset, file)

	if *simplifyAST {
		simplify(file)
	}

	// Keyed fields: an unkeyed two-value literal does not compile once
	// printer.Config has more than two fields.
	var buf bytes.Buffer
	err = (&printer.Config{Mode: printerMode, Tabwidth: *tabWidth}).Fprint(&buf, fset, file)
	if err != nil {
		return err
	}
	res := buf.Bytes()
	if adjust != nil {
		res = adjust(src, res)
	}

	if !bytes.Equal(src, res) {
		// formatting has changed
		if *list {
			fmt.Fprintln(out, filename)
		}
		if *write {
			// The permission argument is only used when the file has to be
			// created.
			err = ioutil.WriteFile(filename, res, 0)
			if err != nil {
				return err
			}
		}
		if *doDiff {
			data, err := diff(src, res)
			if err != nil {
				return fmt.Errorf("computing diff: %s", err)
			}
			fmt.Printf("diff %s gofmt/%s\n", filename, filename)
			out.Write(data)
		}
	}

	if !*list && !*write && !*doDiff {
		_, err = out.Write(res)
	}

	return err
}
Exemple #28
0
// processFile formats one source and reports the result as selected by the
// list, write and doDiff settings.
//
// If in == nil, the source is the contents of the file with the given filename,
// and the permission bits of that file are reused when it is rewritten;
// otherwise 0644 is used. When the formatted source differs from the input,
// the file name is printed to out (list), the file is overwritten after a
// temporary backup has been made (write) and/or a diff is emitted (doDiff).
// When none of the three settings is enabled, the formatted source is
// written to out.
func processFile(filename string, in io.Reader, out io.Writer, stdin bool) error {
	perm := os.FileMode(0644)
	if in == nil {
		f, err := os.Open(filename)
		if err != nil {
			return err
		}
		defer f.Close()
		fi, err := f.Stat()
		if err != nil {
			return err
		}
		in, perm = f, fi.Mode().Perm()
	}

	src, err := ioutil.ReadAll(in)
	if err != nil {
		return err
	}

	file, sourceAdj, indentAdj, err := parse(fileSet, filename, src, stdin)
	if err != nil {
		return err
	}

	switch {
	case rewrite == nil:
		// no rewrite rule configured
	case sourceAdj == nil:
		file = rewrite(file)
	default:
		fmt.Fprintf(os.Stderr, "warning: rewrite ignored for incomplete programs\n")
	}

	ast.SortImports(fileSet, file)

	if *simplifyAST {
		simplify(file)
	}

	cfg := printer.Config{Mode: printerMode, Tabwidth: tabWidth}
	res, err := format(fileSet, file, sourceAdj, indentAdj, src, cfg)
	if err != nil {
		return err
	}

	// Everything in this block only applies when formatting changed the source.
	if !bytes.Equal(src, res) {
		if *list {
			fmt.Fprintln(out, filename)
		}
		if *write {
			// make a temporary backup before overwriting original
			bakname, err := backupFile(filename+".", src, perm)
			if err != nil {
				return err
			}
			if err := ioutil.WriteFile(filename, res, perm); err != nil {
				// Put the backup back in place of the damaged file.
				os.Rename(bakname, filename)
				return err
			}
			if err := os.Remove(bakname); err != nil {
				return err
			}
		}
		if *doDiff {
			patch, derr := diff(src, res)
			if derr != nil {
				return fmt.Errorf("computing diff: %s", derr)
			}
			fmt.Printf("diff %s gofmt/%s\n", filename, filename)
			out.Write(patch)
		}
	}

	if *list || *write || *doDiff {
		return nil
	}
	_, err = out.Write(res)
	return err
}
Exemple #29
0
// processFile fixes and sorts the imports of one source, formats it and
// reports the result as selected by the list, write and doDiff settings.
//
// If in == nil, the source is the contents of the file with the given filename.
// When the formatted source differs from the input, the file name is printed
// to out (list), the file is overwritten (write) and/or a diff is emitted
// (doDiff). When none of the three settings is enabled, the formatted source
// is written to out.
func processFile(filename string, in io.Reader, out io.Writer, stdin bool) error {
	if in == nil {
		fh, openErr := os.Open(filename)
		if openErr != nil {
			return openErr
		}
		defer fh.Close()
		in = fh
	}

	src, err := ioutil.ReadAll(in)
	if err != nil {
		return err
	}

	file, adjust, err := parse(fileSet, filename, src, stdin)
	if err != nil {
		return err
	}

	fixImports(file)
	ast.SortImports(fileSet, file)

	cfg := printer.Config{Mode: printerMode, Tabwidth: 8}
	var printed bytes.Buffer
	if err := cfg.Fprint(&printed, fileSet, file); err != nil {
		return err
	}
	res := printed.Bytes()
	if adjust != nil {
		res = adjust(src, res)
	}

	// Everything in this block only applies when formatting changed the source.
	if !bytes.Equal(src, res) {
		if *list {
			fmt.Fprintln(out, filename)
		}
		if *write {
			if werr := ioutil.WriteFile(filename, res, 0); werr != nil {
				return werr
			}
		}
		if *doDiff {
			patch, derr := diff(src, res)
			if derr != nil {
				return fmt.Errorf("computing diff: %s", derr)
			}
			fmt.Printf("diff %s gofmt/%s\n", filename, filename)
			out.Write(patch)
		}
	}

	if *list || *write || *doDiff {
		return nil
	}
	_, err = out.Write(res)
	return err
}