コード例 #1
0
ファイル: extractMethod.go プロジェクト: vpavkin/GoRefactor
// checkScoping inspects the statements that come after the extracted range
// inside block.
//
// block is the node whose statement list (obtained with getStmtList) holds the
// extracted statements, and stmtList is the extracted range itself. Each
// statement positioned after the last extracted one is walked with a
// checkScopingVisitor constructed from declaredInExtracted, a fresh symbol
// table for the same package, and globalIdentMap. The visitor is defined
// elsewhere; presumably it records symbols declared in the extracted code that
// the remaining code still uses.
//
// It returns (true, nil) when no statement follows the extracted range or when
// the visitor's errs table is empty, and (false, errs) otherwise. It panics if
// the last extracted statement cannot be located in block's statement list.
func checkScoping(block ast.Node, stmtList []ast.Stmt, declaredInExtracted *st.SymbolTable, globalIdentMap st.IdentifierMap) (bool, *st.SymbolTable) {
	enclosing := getStmtList(block)
	lastIdx, located := getIndexOfStmt(stmtList[len(stmtList)-1], enclosing)
	if !located {
		panic("didn't find extracted code's end")
	}

	// Nothing follows the extracted code, so nothing can reference it.
	cursor := lastIdx + 1
	if cursor == len(enclosing) {
		return true, nil
	}

	visitor := &checkScopingVisitor{declaredInExtracted, st.NewSymbolTable(declaredInExtracted.Package), globalIdentMap}
	for ; cursor < len(enclosing); cursor++ {
		ast.Walk(visitor, enclosing[cursor])
	}

	if visitor.errs.Count() == 0 {
		return true, nil
	}
	return false, visitor.errs
}
コード例 #2
0
ファイル: methodsVisitor.go プロジェクト: vpavkin/GoRefactor
// Visit implements ast.Visitor. It registers a function symbol for every
// *ast.FuncDecl it is handed and always returns mv, so the walk continues with
// the same visitor; every other node kind is passed through untouched.
//
// For a function declaration it:
//   - resolves the declaration's type to a *st.FunctionTypeSymbol (panicking
//     with "unexpected cycle" if st.GetBaseType reports a cycle);
//   - creates the function's local symbol table and opens the parameter,
//     result and receiver scopes on it;
//   - for a method, parses each receiver field, strips one level of pointer to
//     find the type that owns the method, makes sure that type has a method
//     table, and adds the receiver variable(s) to ft.Reciever (unnamed
//     receivers get a synthetic "$unnamed receiverN" name);
//   - attaches the new function symbol either to the receiver type's method
//     table or, for a plain function, to the parser's root symbol table.
func (mv *methodsVisitor) Visit(node ast.Node) (w ast.Visitor) {
	f, isFuncDecl := node.(*ast.FuncDecl)
	if !isFuncDecl {
		return mv
	}

	fft, cyc := st.GetBaseType(mv.Parser.parseTypeSymbol(f.Type))
	if cyc {
		panic("unexpected cycle")
	}
	ft := fft.(*st.FunctionTypeSymbol)

	locals := st.NewSymbolTable(mv.Parser.Package)
	locals.AddOpenedScope(ft.Parameters)
	locals.AddOpenedScope(ft.Results)
	locals.AddOpenedScope(ft.Reciever)

	// basertype is the receiver type exactly as written; rtype is the same type
	// with one pointer level removed, i.e. the type the method is attached to.
	var basertype, rtype st.ITypeSymbol
	if f.Recv != nil {
		unnamedCount := 0
		for _, field := range f.Recv.List {
			basertype = mv.Parser.parseTypeSymbol(field.Type)
			if prtype, ok := basertype.(*st.PointerTypeSymbol); ok {
				rtype = prtype.BaseType
			} else {
				rtype = basertype
			}

			// Lazily create the method table. A leftover debugging panic used
			// to sit in front of this call and made it unreachable.
			if rtype.Methods() == nil {
				rtype.SetMethods(st.NewSymbolTable(mv.Parser.Package))
			}

			if len(field.Names) == 0 {
				toAdd := st.MakeVariable("$unnamed receiver"+strconv.Itoa(unnamedCount), ft.Reciever, basertype)
				ft.Reciever.AddSymbol(toAdd)
				unnamedCount++
			}

			for _, name := range field.Names {
				toAdd := st.MakeVariable(name.Name, ft.Reciever, basertype)
				mv.Parser.registerIdent(toAdd, name)
				ft.Reciever.AddSymbol(toAdd)
			}
		}
	}

	// The scope is not known yet; it is assigned below once we know whether
	// this is a method or a package-level function.
	toAdd := st.MakeFunction(f.Name.Name, nil, ft)
	toAdd.Locals = locals

	mv.Parser.registerIdent(toAdd, f.Name)

	if f.Recv != nil {
		rtype.AddMethod(toAdd)
		toAdd.Scope_ = rtype.Methods()
	} else {
		mv.Parser.RootSymbolTable.AddSymbol(toAdd)
		toAdd.Scope_ = mv.Parser.RootSymbolTable
	}
	return mv
}
コード例 #3
0
ファイル: extractMethod.go プロジェクト: vpavkin/GoRefactor
// extractMethod performs the "extract method" refactoring on filename: the
// code between (lineStart, colStart) and (lineEnd, colEnd) is moved into a new
// function named methodName and replaced by a call to it.
//
// If recieverVarLine/recieverVarCol resolve to a variable (via
// getRecieverSymbol), the new function is built with that variable as its
// receiver and the variable is removed from the parameter list.
//
// Two shapes of selection are handled, distinguished by the node returned by
// getExtractedStatementList:
//   - nodeFrom != nil: a run of statements inside nodeFrom's statement list is
//     deleted and a single call statement is inserted in its place;
//   - nodeFrom == nil: stmtList[0] is taken to be a *ast.ReturnStmt and its
//     result expressions are replaced by the call expression.
//
// It returns (true, nil) on success and (false, err) when validation or one of
// the rewriting helpers fails. It panics if the first extracted statement
// cannot be found in nodeFrom's statement list, and (through the unchecked
// type assertions) if a re-located node is not an ast.Stmt or stmtList[0] is
// not a *ast.ReturnStmt in the nodeFrom == nil case.
//
// NOTE(review): the fmt.Printf / println / print calls below look like
// debugging output left in place; confirm whether callers rely on them.
func extractMethod(programTree *program.Program, filename string, lineStart int, colStart int, lineEnd int, colEnd int, methodName string, recieverVarLine int, recieverVarCol int) (bool, *errors.GoRefactorError) {

	// Validate the raw arguments before touching the program tree.
	if ok, err := CheckExtractMethodParameters(filename, lineStart, colStart, lineEnd, colEnd, methodName, recieverVarLine, recieverVarCol); !ok {
		return false, err
	}

	pack, file := programTree.FindPackageAndFileByFilename(filename)
	if pack == nil {
		return false, errors.ArgumentError("filename", "Program packages don't contain file '"+filename+"'")
	}
	fset := pack.FileSet

	// recvSym stays nil when no receiver was requested (see the nil checks
	// below); the exact rule lives in getRecieverSymbol.
	recvSym, err := getRecieverSymbol(programTree, pack, filename, recieverVarLine, recieverVarCol)
	if err != nil {
		return false, err
	}

	// Refuse a method name that would collide with something already visible:
	// a method or field of the receiver type, or a package-level symbol when
	// the result is a plain function.
	if recvSym != nil {
		if recvSym.VariableType.Methods() != nil {
			if _, ok := recvSym.VariableType.Methods().LookUp(methodName, ""); ok {
				return false, errors.ArgumentError("methodName", "reciever already contains a method with name "+methodName)
			}
		}
		switch t := recvSym.VariableType.(type) {
		case *st.StructTypeSymbol:
			if _, ok := t.Fields.LookUp(methodName, ""); ok {
				return false, errors.ArgumentError("methodName", "reciever already contains a field with name "+methodName)
			}
		case *st.PointerTypeSymbol:
			if _, ok := t.Fields.LookUp(methodName, ""); ok {
				return false, errors.ArgumentError("methodName", "reciever already contains a field with name "+methodName)
			}
		}
	} else {
		if _, ok := pack.Symbols.LookUp(methodName, ""); ok {
			return false, errors.ArgumentError("methodName", "package already contains a symbol with name "+methodName)
		}
	}

	// stmtList is the code to extract; nodeFrom is the node whose statement
	// list contains it, or nil for the return-expression case handled in the
	// else branch further down.
	stmtList, nodeFrom, err := getExtractedStatementList(pack, file, filename, lineStart, colStart, lineEnd, colEnd)
	if err != nil {
		return false, err
	}
	fmt.Printf("list pos,end = %d,%d\n", stmtList[0].Pos(), stmtList[len(stmtList)-1].End())

	// params: symbols the extracted code needs from outside; declared: symbols
	// it declares itself (later passed to checkScoping).
	params, declared := getParametersAndDeclaredIn(pack, stmtList, programTree)
	fmt.Printf("list pos,end = %d,%d\n", stmtList[0].Pos(), stmtList[len(stmtList)-1].End())
	if recvSym != nil {
		// The receiver must be one of the inputs of the extracted code; it is
		// then passed as the receiver instead of as an ordinary parameter.
		if _, found := params.LookUp(recvSym.Name(), ""); !found {
			return false, &errors.GoRefactorError{ErrorType: "extract method error", Message: "symbol, desired to be reciever, is not a parameter to extracted code"}
		}
		params.RemoveSymbol(recvSym.Name())
	}

	// Build the result table: one unnamed variable per result type.
	resultList := getResultList(programTree, pack, filename, stmtList)
	results := st.NewSymbolTable(pack)
	for _, r := range resultList {
		results.AddSymbol(st.MakeVariable(st.NO_NAME, results, r))
	}
	fmt.Printf("list pos,end = %d,%d\n", stmtList[0].Pos(), stmtList[len(stmtList)-1].End())
	// pointerSymbols maps a symbol to an integer printed as "depth" below;
	// presumably these are parameters that must be passed by pointer — verify
	// against getPointerPassedSymbols.
	pointerSymbols := getPointerPassedSymbols(stmtList, params, programTree.IdentMap)
	fmt.Printf("list pos,end = %d,%d\n", stmtList[0].Pos(), stmtList[len(stmtList)-1].End())
	for s, depth := range pointerSymbols {
		println(s.Name(), depth)
	}

	applyPointerTransform(fset, file, stmtList, pointerSymbols, programTree.IdentMap)
	fmt.Printf("list pos,end = %d,%d\n", stmtList[0].Pos(), stmtList[len(stmtList)-1].End())
	// The declaration of the new function; it is added to the file at the very
	// end, after the call site has been rewritten.
	fdecl := makeFuncDecl(methodName, stmtList, params, pointerSymbols, results, recvSym, pack, filename)

	if nodeFrom != nil {
		// Statement-list case: replace the extracted statements by one call
		// statement positioned at the start of the extracted range.

		callExpr, callExprLen := makeCallExpr(methodName, params, pointerSymbols, stmtList[0].Pos(), recvSym, pack, filename)

		// Abort if code after the extracted range still needs something the
		// extracted code declares; report the offending symbol names.
		if ok, errs := checkScoping(nodeFrom, stmtList, declared, programTree.IdentMap); !ok {
			s := ""
			errs.ForEach(func(sym st.Symbol) {
				s += sym.Name() + " "
			})
			return false, &errors.GoRefactorError{ErrorType: "extract method error", Message: "extracted code declares symbols that are used in not-extracted code: " + s}
		}

		// app is callExprLen minus the position span of the extracted
		// statements; a positive value means the call does not fit into the
		// space the statements occupied.
		app := callExprLen - int(stmtList[len(stmtList)-1].End()-stmtList[0].Pos())
		if app > 0 {

			// Remember where everything is as token.Position values, because
			// the file is re-parsed below and the old nodes must be found
			// again in the new tree.
			S, E := fset.Position(nodeFrom.Pos()), fset.Position(nodeFrom.End())
			poses, ends := make([]token.Position, len(stmtList)), make([]token.Position, len(stmtList))
			for i, stmt := range stmtList {
				poses[i], ends[i] = fset.Position(stmt.Pos()), fset.Position(stmt.End())
			}

			tfile := printerUtil.GetFileFromFileSet(fset, filename)
			baseMod := tfile.Base()
			fmt.Printf("app = %d,baseMod = %d\n", app, baseMod)
			// Re-parse the file with app passed as the third argument
			// (presumably extra room — confirm in printerUtil.ReparseFile),
			// then drop the last app entries from the new file's line table.
			fset, file = printerUtil.ReparseFile(file, filename, app, programTree.IdentMap)
			tfile = printerUtil.GetFileFromFileSet(fset, filename)
			lines := printerUtil.GetLines(tfile)
			tfile.SetLines(lines[:len(lines)-(app)])

			nodeFrom = printerUtil.FindNode(fset, file, S, E)
			// When the old file's base was not 1, shift the nodes created
			// against the old file set by 1-baseMod.
			if baseMod != 1 {
				for _, stmt := range stmtList {
					printerUtil.FixPositions(0, 1-baseMod, stmt, true)
				}
				printerUtil.FixPositions(0, 1-baseMod, callExpr, true)
				printerUtil.FixPositions(0, 1-baseMod, fdecl, true)
			}
			// Swap the old statement nodes for their counterparts in the
			// re-parsed tree, located by the positions saved above.
			stmtList = make([]ast.Stmt, len(stmtList))
			for i, _ := range stmtList {
				stmtList[i] = printerUtil.FindNode(fset, file, poses[i], ends[i]).(ast.Stmt)
			}
		}

		list := getStmtList(nodeFrom)

		// ind is the index of the first extracted statement; the call is
		// inserted at this index after the extracted statements are deleted.
		ind, found := getIndexOfStmt(stmtList[0], list)
		if !found {
			panic("didn't find replace origin")
		}
		fmt.Printf("stmtList length = %d\n", len(stmtList))
		if ok, err := printerUtil.DeleteNodeList(fset, filename, file, stmtList); !ok {
			return false, err
		}
		// Re-read the list after the deletion and build a copy that is one
		// element longer: [0,ind) unchanged, the call at ind, the rest shifted
		// right by one.
		list = getStmtList(nodeFrom)
		newList := make([]ast.Stmt, len(list)+1)
		copy(newList, list[0:ind])
		newList[ind] = &ast.ExprStmt{callExpr}
		for i := ind; i < len(list); i++ {
			newList[i+1] = list[i]
		}
		printerUtil.AddLineForRange(fset, filename, callExpr.Pos(), callExpr.End())
		setStmtList(nodeFrom, newList)
		// Adjust positions in the file by callExprLen from the call's
		// position on, leaving the call expression itself untouched.
		printerUtil.FixPositionsExcept(callExpr.Pos(), callExprLen, file, true, map[ast.Node]bool{callExpr: true})

	} else {
		// Expression case: the selection is the result list of a return
		// statement, which is replaced by the call expression.

		//stmtList[0] = utils.CopyAstNode(stmtList[0]).(ast.Stmt)
		rs := stmtList[0].(*ast.ReturnStmt)

		callExpr, callExprLen := makeCallExpr(methodName, params, pointerSymbols, rs.Results[0].Pos(), recvSym, pack, filename)
		// mod is callExprLen minus the position span of the original results;
		// baseMod is overwritten by ModifyLine on the next line.
		mod, baseMod := callExprLen-int(rs.Results[len(rs.Results)-1].End()-rs.Results[0].Pos()), 0
		fset, file, baseMod = printerUtil.ModifyLine(pack.FileSet, file, filename, programTree.IdentMap, callExpr.Pos(), mod)

		// Same base correction as in the statement-list case.
		if baseMod != 1 {
			fmt.Printf("baseMod = %d\n", baseMod)
			printerUtil.FixPositions(0, 1-baseMod, callExpr, true)
			printerUtil.FixPositions(0, 1-baseMod, rs, true)
			printerUtil.FixPositions(0, 1-baseMod, fdecl, true)
		}
		fmt.Printf("results st,end = %d,%d, callExpr pos,end = %d,%d\n", rs.Results[0].Pos(), rs.Results[len(rs.Results)-1].End(), callExpr.Pos(), callExpr.End())
		// Replace the expressions spanning the old results by the single call
		// expression; only an error stored under EXTRACT_METHOD aborts.
		errs := replaceExprList(fset.Position(rs.Results[0].Pos()), fset.Position(rs.Results[len(rs.Results)-1].End()), []ast.Expr{callExpr}, fset, file)
		if err, ok := errs[EXTRACT_METHOD]; ok {
			return false, err
		}

		fmt.Printf("mod = %d\n", mod)
		printerUtil.FixPositionsExcept(callExpr.Pos(), mod, file, true, map[ast.Node]bool{callExpr: true})
	}
	// Save the rewritten call site first, then add the new declaration and
	// save again with the file returned by AddDeclExplicit.
	programTree.SaveFileExplicit(filename, fset, file)
	print("AAAAAA")
	// Note: fset declared in the if header shadows the outer fset inside both
	// branches of this statement.
	if ok, fset, newF, err := printerUtil.AddDeclExplicit(fset, filename, file, fset, filename, file, fdecl, programTree.IdentMap); !ok {
		return false, err
	} else {
		print("BBBBBB")
		programTree.SaveFileExplicit(filename, fset, newF)
	}
	return true, nil

}
コード例 #4
0
ファイル: program.go プロジェクト: vpavkin/GoRefactor
// ParseProgram builds the package-level program value from the given source
// folders and returns it.
//
// sources maps a source folder to its Go path; specialPackages is indexed by
// Go path when locating packages and is also handed to parseImports.
// projectDir is not used by this function.
//
// The steps are:
//  1. reset the package-level program and call initialize;
//  2. record every source folder in packages, then locate each package;
//  3. parse imports for the packages known at this point (parseImports may add
//     further entries to program.Packages, hence the snapshot);
//  4. flag packages for which IsGoSrcPackage reports true;
//  5. open the base symbol table on every package and start one
//     packageParser.ParsePackage goroutine per package;
//  6. run six rounds of send-then-receive on each package's Communication
//     channel. Judging by the debug messages that used to follow the rounds,
//     they presumably separate type resolving, fixing, opening, parsing
//     globals, fixing globals and parsing locals — verify against ParsePackage.
func ParseProgram(projectDir string, sources map[string]string, specialPackages map[string][]string) *Program {
	program = &Program{st.NewSymbolTable(nil), make(map[string]*st.Package), make(map[*ast.Ident]st.Symbol)}

	initialize()
	for dir, importPath := range sources {
		packages[dir] = importPath
	}
	for dir, importPath := range sources {
		locatePackage(dir, specialPackages[importPath])
	}

	// Iterate over a snapshot: parsing imports recursively fills
	// program.Packages while we are going through it.
	snapshot := new(vector.Vector)
	for _, known := range program.Packages {
		snapshot.Push(known)
	}
	for _, entry := range *snapshot {
		parseImports(entry.(*st.Package), specialPackages)
	}

	for _, p := range program.Packages {
		if IsGoSrcPackage(p) {
			p.IsGoPackage = true
		}
	}

	for _, p := range program.Packages {
		p.Symbols.AddOpenedScope(program.BaseSymbolTable)
		go packageParser.ParsePackage(p, program.IdentMap)
	}

	// Each round performs one send followed by one receive per package.
	const syncRounds = 6
	for round := 0; round < syncRounds; round++ {
		for _, p := range program.Packages {
			p.Communication <- 0
			<-p.Communication
		}
	}

	return program
}
コード例 #5
0
ファイル: localsVisitor.go プロジェクト: vpavkin/GoRefactor
// parseBlockStmt handles a node that introduces a new lexical scope and
// returns the visitor that ast.Walk should use for the node's children.
//
// A fresh symbol table is created with lv.Current opened on it, and a child
// visitor ww is built around that table. While this call runs,
// lv.Parser.CurrentSymbolTable points at the new table; the previous value is
// restored by the deferred function on return.
//
// Return value:
//   - nil when node is nil, and for every statement kind whose children are
//     walked here explicitly (for, if, range, switch, type switch, case and
//     comm clauses), so the caller does not descend into them a second time;
//   - ww for *ast.SelectStmt, so its children are visited in the new scope;
//   - a new visitor bound to the literal's own function symbol and locals for
//     *ast.FuncLit;
//   - lv itself for any node type not listed in the switch.
//
// NOTE(review): node is an interface{}, so the nil check only catches an
// untyped nil; a nil pointer wrapped in the interface still reaches the switch.
func (lv *innerScopeVisitor) parseBlockStmt(node interface{}) (w ast.Visitor) {
	if node == nil {
		return nil
	}
	w = lv
	table := st.NewSymbolTable(lv.Parser.Package)
	// 	fmt.Printf(" %p %p %p \n", lv.Parser.CurrentSymbolTable, lv.Current, lv.Method.Locals)
	table.AddOpenedScope(lv.Current)
	ww := &innerScopeVisitor{lv.Method, table, lv.Parser, nil, lv.LabelsData}

	// Make the new table the parser's current one for the duration of this
	// call only.
	temp := lv.Parser.CurrentSymbolTable
	lv.Parser.CurrentSymbolTable = table
	defer func() { lv.Parser.CurrentSymbolTable = temp }()

	switch inNode := node.(type) {
	case *ast.ForStmt:
		// Init, condition, post statement and body all share the scope of ww.
		ww.parseStmt(inNode.Init)
		ww.Parser.parseExpr(inNode.Cond)
		ww.parseStmt(inNode.Post)
		ast.Walk(ww, inNode.Body)
		w = nil
	case *ast.IfStmt:
		// Init and condition live in ww's scope; the body and the else branch
		// each get their own table opened on it.
		ww.parseStmt(inNode.Init)
		ww.Parser.parseExpr(inNode.Cond)
		ww1 := &innerScopeVisitor{lv.Method, st.NewSymbolTable(lv.Parser.Package), lv.Parser, nil, lv.LabelsData}
		ww2 := &innerScopeVisitor{lv.Method, st.NewSymbolTable(lv.Parser.Package), lv.Parser, nil, lv.LabelsData}
		ww1.Current.AddOpenedScope(ww.Current)
		ww2.Current.AddOpenedScope(ww.Current)
		ast.Walk(ww1, inNode.Body)
		// NOTE(review): inNode.Else is nil when there is no else branch;
		// confirm that walking a nil node with this visitor is safe.
		ast.Walk(ww2, inNode.Else)
		w = nil
	case *ast.RangeStmt:
		// The type of the ranged-over expression decides the key/value types.
		rangeType := ww.Parser.parseExpr(inNode.X).At(0).(st.ITypeSymbol)
		// 		fmt.Printf("range type = %s, %T\n", rangeType.Name(), rangeType)
		switch inNode.Tok {
		case token.DEFINE:
			// `:=` form: the range variables are new symbols in ww's scope.
			// The second result of GetBaseType is discarded; a nil base type
			// is what triggers the panic here.
			if rangeType, _ = st.GetBaseType(rangeType); rangeType == nil {
				panic("unexpected cycle")
			}

			// kT/vT stay nil for base types not listed below.
			var kT, vT st.ITypeSymbol
			switch rT := rangeType.(type) {
			case *st.ArrayTypeSymbol:
				kT = st.PredeclaredTypes["int"]
				vT = rT.ElemType
			case *st.MapTypeSymbol:
				kT = rT.KeyType
				vT = rT.ValueType
			case *st.BasicTypeSymbol: //string
				// NOTE(review): the Go spec gives the second range value over
				// a string the type rune, not byte — confirm this is intended.
				kT = st.PredeclaredTypes["int"]
				vT = st.PredeclaredTypes["byte"]
			case *st.ChanTypeSymbol:
				// A channel yields a single variable, typed as the element.
				kT = rT.ValueType
			case *st.UnresolvedTypeSymbol:
				panic("unresolved at range")
			}
			// The blank identifier declares nothing, so it is skipped.
			iK := inNode.Key.(*ast.Ident)
			if iK.Name != "_" {

				toAdd := st.MakeVariable(iK.Name, ww.Current, kT)
				ww.Parser.registerIdent(toAdd, iK)
				ww.Current.AddSymbol(toAdd)

				// 				fmt.Printf("range key added %s %T\n", toAdd.Name(), toAdd)
			}
			if inNode.Value != nil { // not channel, two range vars
				iV := inNode.Value.(*ast.Ident)
				if iV.Name != "_" {
					toAdd := st.MakeVariable(iV.Name, ww.Current, vT)
					ww.Parser.registerIdent(toAdd, iV)
					ww.Current.AddSymbol(toAdd)
					// 					fmt.Printf("range value added %s %T\n", toAdd.Name(), toAdd)
				}
			}
		case token.ASSIGN:
			// `=` form: key and value are existing expressions; just parse them.
			ww.Parser.parseExpr(inNode.Key)
			if inNode.Value != nil {
				ww.Parser.parseExpr(inNode.Value)
			}
		}
		ast.Walk(ww, inNode.Body)
		// 		fmt.Printf("end of range\n")
		w = nil
	case *ast.SelectStmt:
		// Nothing to declare here; let the walk continue with the new scope.
		w = ww
	case *ast.SwitchStmt:
		ww.parseStmt(inNode.Init)
		ww.Parser.parseExpr(inNode.Tag)
		ast.Walk(ww, inNode.Body)
		w = nil
	case *ast.TypeSwitchStmt:
		ww.parseStmt(inNode.Init)
		switch tsT := inNode.Assign.(type) {
		case *ast.AssignStmt:
			// `switch v := x.(type)`: declare v with the type of x and mark it
			// as the type-switch variable so case clauses can find it.
			tsVar := tsT.Lhs[0].(*ast.Ident)
			tsTypeAss := tsT.Rhs[0].(*ast.TypeAssertExpr)
			tsType := ww.Parser.parseExpr(tsTypeAss.X).At(0).(st.ITypeSymbol)

			toAdd := st.MakeVariable(tsVar.Name, ww.Current, tsType)
			toAdd.IsTypeSwitchVar = true
			ww.Parser.registerIdent(toAdd, tsVar)
			ww.Current.AddSymbol(toAdd)
		case *ast.ExprStmt:
			// `switch x.(type)`: no variable, only the operand is parsed.
			tsTypeAss := tsT.X.(*ast.TypeAssertExpr)
			ww.Parser.parseExpr(tsTypeAss.X)
		}
		ast.Walk(ww, inNode.Body)
		w = nil
	case *ast.CaseClause:

		// A nil List is the default clause; there is nothing to parse then.
		if inNode.List != nil {
			// The clause is treated as part of a type switch when its first
			// entry parses to a type symbol.
			s := ww.Parser.parseExpr(inNode.List[0]).At(0)
			_, isTypeSwitch := s.(st.ITypeSymbol)

			if isTypeSwitch {
				switch {
				case len(inNode.List) == 1:
					// Single type: redeclare the type-switch variable in this
					// clause's scope with the clause's type, sharing the
					// original variable's Idents and Posits.
					tsType := ww.Parser.parseExpr(inNode.List[0]).At(0).(st.ITypeSymbol)
					if tsVar, ok := lv.Current.FindTypeSwitchVar(); ok {

						toAdd := st.MakeVariable(tsVar.Name(), ww.Current, tsType)
						toAdd.Idents = tsVar.Idents
						toAdd.Posits = tsVar.Posits
						//No position, just register symbol
						ww.Current.AddSymbol(toAdd)
					}
				case len(inNode.List) > 1:
					// Several types: only parse them, no variable is added.
					for _, t := range inNode.List {
						ww.Parser.parseExpr(t)
					}
				}
			} else {
				// Ordinary expression switch: parse every case expression.
				for _, v := range inNode.List {
					ww.Parser.parseExpr(v)
				}
			}
		}

		for _, stmt := range inNode.Body {
			ast.Walk(ww, stmt)
		}
		w = nil
	case *ast.CommClause:
		// The communication statement is parsed in the clause's own scope
		// before the body is walked.
		ww.parseStmt(inNode.Comm)

		for _, stmt := range inNode.Body {
			ast.Walk(ww, stmt)
		}
		w = nil
	case *ast.FuncLit:
		// A function literal becomes a function symbol named "#" whose locals
		// see the enclosing scope plus its own parameters and results.
		meth := st.MakeFunction("#", ww.Current, lv.Parser.parseTypeSymbol(inNode.Type))
		meth.Locals = st.NewSymbolTable(ww.Parser.Package)
		meth.Locals.AddOpenedScope(lv.Current)
		if meth.FunctionType.(*st.FunctionTypeSymbol).Parameters != nil {
			meth.Locals.AddOpenedScope(meth.FunctionType.(*st.FunctionTypeSymbol).Parameters)
		}
		if meth.FunctionType.(*st.FunctionTypeSymbol).Results != nil {
			meth.Locals.AddOpenedScope(meth.FunctionType.(*st.FunctionTypeSymbol).Results)
		}
		w = &innerScopeVisitor{meth, meth.Locals, lv.Parser, nil, lv.LabelsData}
	}
	return w
}