func checkScoping(block ast.Node, stmtList []ast.Stmt, declaredInExtracted *st.SymbolTable, globalIdentMap st.IdentifierMap) (bool, *st.SymbolTable) { source := getStmtList(block) i, found := getIndexOfStmt(stmtList[len(stmtList)-1], source) if !found { panic("didn't find extracted code's end") } if i == len(source)-1 { return true, nil } vis := &checkScopingVisitor{declaredInExtracted, st.NewSymbolTable(declaredInExtracted.Package), globalIdentMap} for j := i + 1; j < len(source); j++ { ast.Walk(vis, source[j]) } if vis.errs.Count() > 0 { return false, vis.errs } return true, nil }
func (mv *methodsVisitor) Visit(node ast.Node) (w ast.Visitor) { w = mv switch f := node.(type) { case *ast.FuncDecl: fft, cyc := st.GetBaseType(mv.Parser.parseTypeSymbol(f.Type)) if cyc { panic("unexpected cycle") } ft := fft.(*st.FunctionTypeSymbol) locals := st.NewSymbolTable(mv.Parser.Package) locals.AddOpenedScope(ft.Parameters) locals.AddOpenedScope(ft.Results) locals.AddOpenedScope(ft.Reciever) var basertype, rtype st.ITypeSymbol if f.Recv != nil { e_count := 0 for _, field := range f.Recv.List { basertype = mv.Parser.parseTypeSymbol(field.Type) if prtype, ok := basertype.(*st.PointerTypeSymbol); ok { rtype = prtype.BaseType } else { rtype = basertype } if mv.Parser.Package.AstPackage.Name == "os" { // fmt.Printf("###@@@### (%s) %s\n", rtype.Name(), f.Name.Name) } if rtype.Methods() == nil { panic("ok, this is a test panic") rtype.SetMethods(st.NewSymbolTable(mv.Parser.Package)) } if len(field.Names) == 0 { toAdd := st.MakeVariable("$unnamed receiver"+strconv.Itoa(e_count), ft.Reciever, basertype) ft.Reciever.AddSymbol(toAdd) e_count += 1 } for _, name := range field.Names { toAdd := st.MakeVariable(name.Name, ft.Reciever, basertype) mv.Parser.registerIdent(toAdd, name) ft.Reciever.AddSymbol(toAdd) } } } toAdd := st.MakeFunction(f.Name.Name, nil, ft) // Scope is set 5 lines down toAdd.Locals = locals mv.Parser.registerIdent(toAdd, f.Name) if f.Recv != nil { rtype.AddMethod(toAdd) toAdd.Scope_ = rtype.Methods() } else { mv.Parser.RootSymbolTable.AddSymbol(toAdd) toAdd.Scope_ = mv.Parser.RootSymbolTable } } return }
// extractMethod performs the "extract method" refactoring on filename: the
// code selected by (lineStart, colStart)-(lineEnd, colEnd) is moved into a new
// function declaration named methodName and replaced by a call to it.
//
// recieverVarLine/recieverVarCol are forwarded to getRecieverSymbol; when that
// yields a non-nil symbol the symbol is used as the receiver of the new
// method, otherwise a plain function is created.
//
// On success it returns (true, nil). On any validation or rewriting failure it
// returns (false, err). It panics if the first extracted statement cannot be
// found in its enclosing statement list.
//
// NOTE(review): the function still emits debug traces (fmt.Printf to stdout,
// builtin print/println to stderr); consider removing them.
func extractMethod(programTree *program.Program, filename string, lineStart int, colStart int, lineEnd int, colEnd int, methodName string, recieverVarLine int, recieverVarCol int) (bool, *errors.GoRefactorError) {
	// Argument validation is delegated to CheckExtractMethodParameters.
	if ok, err := CheckExtractMethodParameters(filename, lineStart, colStart, lineEnd, colEnd, methodName, recieverVarLine, recieverVarCol); !ok {
		return false, err
	}
	pack, file := programTree.FindPackageAndFileByFilename(filename)
	if pack == nil {
		return false, errors.ArgumentError("filename", "Program packages don't contain file '"+filename+"'")
	}
	fset := pack.FileSet
	recvSym, err := getRecieverSymbol(programTree, pack, filename, recieverVarLine, recieverVarCol)
	if err != nil {
		return false, err
	}
	// Reject methodName if it collides with an existing name: a method or
	// field of the receiver's type, or (without a receiver) a package symbol.
	if recvSym != nil {
		if recvSym.VariableType.Methods() != nil {
			if _, ok := recvSym.VariableType.Methods().LookUp(methodName, ""); ok {
				return false, errors.ArgumentError("methodName", "reciever already contains a method with name "+methodName)
			}
		}
		switch t := recvSym.VariableType.(type) {
		case *st.StructTypeSymbol:
			if _, ok := t.Fields.LookUp(methodName, ""); ok {
				return false, errors.ArgumentError("methodName", "reciever already contains a field with name "+methodName)
			}
		case *st.PointerTypeSymbol:
			if _, ok := t.Fields.LookUp(methodName, ""); ok {
				return false, errors.ArgumentError("methodName", "reciever already contains a field with name "+methodName)
			}
		}
	} else {
		if _, ok := pack.Symbols.LookUp(methodName, ""); ok {
			return false, errors.ArgumentError("methodName", "package already contains a symbol with name "+methodName)
		}
	}
	// stmtList is the code to extract; nodeFrom is the node whose statement
	// list contains it. nodeFrom == nil selects the return-statement path in
	// the else branch further down.
	stmtList, nodeFrom, err := getExtractedStatementList(pack, file, filename, lineStart, colStart, lineEnd, colEnd)
	if err != nil {
		return false, err
	}
	// Debug trace.
	fmt.Printf("list pos,end = %d,%d\n", stmtList[0].Pos(), stmtList[len(stmtList)-1].End())
	params, declared := getParametersAndDeclaredIn(pack, stmtList, programTree)
	// Debug trace.
	fmt.Printf("list pos,end = %d,%d\n", stmtList[0].Pos(), stmtList[len(stmtList)-1].End())
	// The requested receiver must be one of the symbols the extracted code
	// takes as input; it is then dropped from the parameter table.
	if recvSym != nil {
		if _, found := params.LookUp(recvSym.Name(), ""); !found {
			return false, &errors.GoRefactorError{ErrorType: "extract method error", Message: "symbol, desired to be reciever, is not a parameter to extracted code"}
		}
		params.RemoveSymbol(recvSym.Name())
	}
	// Build the result table: one variable named st.NO_NAME per result type.
	resultList := getResultList(programTree, pack, filename, stmtList)
	results := st.NewSymbolTable(pack)
	for _, r := range resultList {
		results.AddSymbol(st.MakeVariable(st.NO_NAME, results, r))
	}
	// Debug trace.
	fmt.Printf("list pos,end = %d,%d\n", stmtList[0].Pos(), stmtList[len(stmtList)-1].End())
	pointerSymbols := getPointerPassedSymbols(stmtList, params, programTree.IdentMap)
	// Debug trace.
	fmt.Printf("list pos,end = %d,%d\n", stmtList[0].Pos(), stmtList[len(stmtList)-1].End())
	// Debug trace of every pointer-passed symbol (builtin println).
	for s, depth := range pointerSymbols {
		println(s.Name(), depth)
	}
	applyPointerTransform(fset, file, stmtList, pointerSymbols, programTree.IdentMap)
	// Debug trace.
	fmt.Printf("list pos,end = %d,%d\n", stmtList[0].Pos(), stmtList[len(stmtList)-1].End())
	fdecl := makeFuncDecl(methodName, stmtList, params, pointerSymbols, results, recvSym, pack, filename)
	if nodeFrom != nil {
		// Statement path: the extracted statements are replaced by a single
		// expression statement holding the call.
		callExpr, callExprLen := makeCallExpr(methodName, params, pointerSymbols, stmtList[0].Pos(), recvSym, pack, filename)
		if ok, errs := checkScoping(nodeFrom, stmtList, declared, programTree.IdentMap); !ok {
			s := ""
			errs.ForEach(func(sym st.Symbol) { s += sym.Name() + " " })
			return false, &errors.GoRefactorError{ErrorType: "extract method error", Message: "extracted code declares symbols that are used in not-extracted code: " + s}
		}
		// app is how much longer the call is than the span of the extracted
		// statements (last End minus first Pos).
		app := callExprLen - int(stmtList[len(stmtList)-1].End()-stmtList[0].Pos())
		if app > 0 {
			// The file is reparsed below, which produces new AST nodes.
			// Record the positions of nodeFrom and of every extracted
			// statement first so they can be located again afterwards.
			S, E := fset.Position(nodeFrom.Pos()), fset.Position(nodeFrom.End())
			poses, ends := make([]token.Position, len(stmtList)), make([]token.Position, len(stmtList))
			for i, stmt := range stmtList {
				poses[i], ends[i] = fset.Position(stmt.Pos()), fset.Position(stmt.End())
			}
			tfile := printerUtil.GetFileFromFileSet(fset, filename)
			baseMod := tfile.Base()
			// Debug trace.
			fmt.Printf("app = %d,baseMod = %d\n", app, baseMod)
			// NOTE(review): presumably ReparseFile makes room for app extra
			// characters; confirm against printerUtil.ReparseFile.
			fset, file = printerUtil.ReparseFile(file, filename, app, programTree.IdentMap)
			tfile = printerUtil.GetFileFromFileSet(fset, filename)
			// Drop the last app entries from the reparsed file's line table.
			lines := printerUtil.GetLines(tfile)
			tfile.SetLines(lines[:len(lines)-(app)])
			nodeFrom = printerUtil.FindNode(fset, file, S, E)
			// Shift nodes that were built against the old file by 1-baseMod
			// when the old file's base was not 1.
			if baseMod != 1 {
				for _, stmt := range stmtList {
					printerUtil.FixPositions(0, 1-baseMod, stmt, true)
				}
				printerUtil.FixPositions(0, 1-baseMod, callExpr, true)
				printerUtil.FixPositions(0, 1-baseMod, fdecl, true)
			}
			// Re-resolve the extracted statements in the reparsed tree.
			stmtList = make([]ast.Stmt, len(stmtList))
			for i, _ := range stmtList {
				stmtList[i] = printerUtil.FindNode(fset, file, poses[i], ends[i]).(ast.Stmt)
			}
		}
		list := getStmtList(nodeFrom)
		ind, found := getIndexOfStmt(stmtList[0], list)
		if !found {
			panic("didn't find replace origin")
		}
		// Debug trace.
		fmt.Printf("stmtList length = %d\n", len(stmtList))
		if ok, err := printerUtil.DeleteNodeList(fset, filename, file, stmtList); !ok {
			return false, err
		}
		// Re-read the statement list after the deletion, then build a copy
		// with the call statement inserted at index ind.
		list = getStmtList(nodeFrom)
		newList := make([]ast.Stmt, len(list)+1)
		copy(newList, list[0:ind])
		newList[ind] = &ast.ExprStmt{callExpr}
		for i := ind; i < len(list); i++ {
			newList[i+1] = list[i]
		}
		printerUtil.AddLineForRange(fset, filename, callExpr.Pos(), callExpr.End())
		setStmtList(nodeFrom, newList)
		// Fix positions in the rest of the file by callExprLen, leaving the
		// call expression itself untouched.
		printerUtil.FixPositionsExcept(callExpr.Pos(), callExprLen, file, true, map[ast.Node]bool{callExpr: true})
	} else {
		// Return-statement path: stmtList[0] must be a *ast.ReturnStmt (the
		// assertion panics otherwise); its result expressions are replaced by
		// the call expression.
		//stmtList[0] = utils.CopyAstNode(stmtList[0]).(ast.Stmt)
		rs := stmtList[0].(*ast.ReturnStmt)
		callExpr, callExprLen := makeCallExpr(methodName, params, pointerSymbols, rs.Results[0].Pos(), recvSym, pack, filename)
		// mod is the length difference between the call and the span of the
		// original result expressions.
		mod, baseMod := callExprLen-int(rs.Results[len(rs.Results)-1].End()-rs.Results[0].Pos()), 0
		fset, file, baseMod = printerUtil.ModifyLine(pack.FileSet, file, filename, programTree.IdentMap, callExpr.Pos(), mod)
		if baseMod != 1 {
			// Debug trace.
			fmt.Printf("baseMod = %d\n", baseMod)
			printerUtil.FixPositions(0, 1-baseMod, callExpr, true)
			printerUtil.FixPositions(0, 1-baseMod, rs, true)
			printerUtil.FixPositions(0, 1-baseMod, fdecl, true)
		}
		// Debug trace.
		fmt.Printf("results st,end = %d,%d, callExpr pos,end = %d,%d\n", rs.Results[0].Pos(), rs.Results[len(rs.Results)-1].End(), callExpr.Pos(), callExpr.End())
		errs := replaceExprList(fset.Position(rs.Results[0].Pos()), fset.Position(rs.Results[len(rs.Results)-1].End()), []ast.Expr{callExpr}, fset, file)
		if err, ok := errs[EXTRACT_METHOD]; ok {
			return false, err
		}
		// Debug trace.
		fmt.Printf("mod = %d\n", mod)
		printerUtil.FixPositionsExcept(callExpr.Pos(), mod, file, true, map[ast.Node]bool{callExpr: true})
	}
	// Save the file with the call in place, then add the new declaration and
	// save the resulting file again.
	programTree.SaveFileExplicit(filename, fset, file)
	// Debug marker (builtin print).
	print("AAAAAA")
	// The fset declared in this if statement shadows the outer fset.
	if ok, fset, newF, err := printerUtil.AddDeclExplicit(fset, filename, file, fset, filename, file, fdecl, programTree.IdentMap); !ok {
		return false, err
	} else {
		// Debug marker (builtin print).
		print("BBBBBB")
		programTree.SaveFileExplicit(filename, fset, newF)
	}
	return true, nil
}
func ParseProgram(projectDir string, sources map[string]string, specialPackages map[string][]string) *Program { program = &Program{st.NewSymbolTable(nil), make(map[string]*st.Package), make(map[*ast.Ident]st.Symbol)} initialize() for fldr, goPath := range sources { packages[fldr] = goPath } for fldr, goPath := range sources { locatePackage(fldr, specialPackages[goPath]) } packs := new(vector.Vector) for _, pack := range program.Packages { packs.Push(pack) } // Recursively fills program.Packages map. for _, ppack := range *packs { pack := ppack.(*st.Package) parseImports(pack, specialPackages) } for _, pack := range program.Packages { if IsGoSrcPackage(pack) { pack.IsGoPackage = true } } for _, pack := range program.Packages { pack.Symbols.AddOpenedScope(program.BaseSymbolTable) go packageParser.ParsePackage(pack, program.IdentMap) } for _, pack := range program.Packages { pack.Communication <- 0 <-pack.Communication } // type resolving // for _, pack := range program.Packages { // // } for _, pack := range program.Packages { pack.Communication <- 0 <-pack.Communication } // for _, pack := range program.Packages { // // } // fmt.Printf("===================All packages stopped fixing \n") for _, pack := range program.Packages { pack.Communication <- 0 <-pack.Communication } // for _, pack := range program.Packages { // // } // fmt.Printf("===================All packages stopped opening \n") for _, pack := range program.Packages { pack.Communication <- 0 <-pack.Communication } // for _, pack := range program.Packages { // // } // fmt.Printf("===================All packages stopped parsing globals \n") for _, pack := range program.Packages { pack.Communication <- 0 <-pack.Communication } // for _, pack := range program.Packages { // // } // fmt.Printf("===================All packages stopped fixing globals \n") for _, pack := range program.Packages { pack.Communication <- 0 <-pack.Communication } // for _, pack := range program.Packages { // // } // 
fmt.Printf("===================All packages stopped parsing locals \n") return program }
func (lv *innerScopeVisitor) parseBlockStmt(node interface{}) (w ast.Visitor) { if node == nil { return nil } w = lv table := st.NewSymbolTable(lv.Parser.Package) // fmt.Printf(" %p %p %p \n", lv.Parser.CurrentSymbolTable, lv.Current, lv.Method.Locals) table.AddOpenedScope(lv.Current) ww := &innerScopeVisitor{lv.Method, table, lv.Parser, nil, lv.LabelsData} temp := lv.Parser.CurrentSymbolTable lv.Parser.CurrentSymbolTable = table defer func() { lv.Parser.CurrentSymbolTable = temp }() switch inNode := node.(type) { case *ast.ForStmt: ww.parseStmt(inNode.Init) ww.Parser.parseExpr(inNode.Cond) ww.parseStmt(inNode.Post) ast.Walk(ww, inNode.Body) w = nil case *ast.IfStmt: ww.parseStmt(inNode.Init) ww.Parser.parseExpr(inNode.Cond) ww1 := &innerScopeVisitor{lv.Method, st.NewSymbolTable(lv.Parser.Package), lv.Parser, nil, lv.LabelsData} ww2 := &innerScopeVisitor{lv.Method, st.NewSymbolTable(lv.Parser.Package), lv.Parser, nil, lv.LabelsData} ww1.Current.AddOpenedScope(ww.Current) ww2.Current.AddOpenedScope(ww.Current) ast.Walk(ww1, inNode.Body) ast.Walk(ww2, inNode.Else) w = nil case *ast.RangeStmt: rangeType := ww.Parser.parseExpr(inNode.X).At(0).(st.ITypeSymbol) // fmt.Printf("range type = %s, %T\n", rangeType.Name(), rangeType) switch inNode.Tok { case token.DEFINE: if rangeType, _ = st.GetBaseType(rangeType); rangeType == nil { panic("unexpected cycle") } var kT, vT st.ITypeSymbol switch rT := rangeType.(type) { case *st.ArrayTypeSymbol: kT = st.PredeclaredTypes["int"] vT = rT.ElemType case *st.MapTypeSymbol: kT = rT.KeyType vT = rT.ValueType case *st.BasicTypeSymbol: //string kT = st.PredeclaredTypes["int"] vT = st.PredeclaredTypes["byte"] case *st.ChanTypeSymbol: kT = rT.ValueType case *st.UnresolvedTypeSymbol: panic("unresolved at range") } iK := inNode.Key.(*ast.Ident) if iK.Name != "_" { toAdd := st.MakeVariable(iK.Name, ww.Current, kT) ww.Parser.registerIdent(toAdd, iK) ww.Current.AddSymbol(toAdd) // fmt.Printf("range key added %s %T\n", toAdd.Name(), toAdd) } if 
inNode.Value != nil { // not channel, two range vars iV := inNode.Value.(*ast.Ident) if iV.Name != "_" { toAdd := st.MakeVariable(iV.Name, ww.Current, vT) ww.Parser.registerIdent(toAdd, iV) ww.Current.AddSymbol(toAdd) // fmt.Printf("range value added %s %T\n", toAdd.Name(), toAdd) } } case token.ASSIGN: ww.Parser.parseExpr(inNode.Key) if inNode.Value != nil { ww.Parser.parseExpr(inNode.Value) } } ast.Walk(ww, inNode.Body) // fmt.Printf("end of range\n") w = nil case *ast.SelectStmt: w = ww case *ast.SwitchStmt: ww.parseStmt(inNode.Init) ww.Parser.parseExpr(inNode.Tag) ast.Walk(ww, inNode.Body) w = nil case *ast.TypeSwitchStmt: ww.parseStmt(inNode.Init) switch tsT := inNode.Assign.(type) { case *ast.AssignStmt: tsVar := tsT.Lhs[0].(*ast.Ident) tsTypeAss := tsT.Rhs[0].(*ast.TypeAssertExpr) tsType := ww.Parser.parseExpr(tsTypeAss.X).At(0).(st.ITypeSymbol) toAdd := st.MakeVariable(tsVar.Name, ww.Current, tsType) toAdd.IsTypeSwitchVar = true ww.Parser.registerIdent(toAdd, tsVar) ww.Current.AddSymbol(toAdd) case *ast.ExprStmt: tsTypeAss := tsT.X.(*ast.TypeAssertExpr) ww.Parser.parseExpr(tsTypeAss.X) } ast.Walk(ww, inNode.Body) w = nil case *ast.CaseClause: if inNode.List != nil { s := ww.Parser.parseExpr(inNode.List[0]).At(0) _, isTypeSwitch := s.(st.ITypeSymbol) if isTypeSwitch { switch { case len(inNode.List) == 1: tsType := ww.Parser.parseExpr(inNode.List[0]).At(0).(st.ITypeSymbol) if tsVar, ok := lv.Current.FindTypeSwitchVar(); ok { toAdd := st.MakeVariable(tsVar.Name(), ww.Current, tsType) toAdd.Idents = tsVar.Idents toAdd.Posits = tsVar.Posits //No position, just register symbol ww.Current.AddSymbol(toAdd) } case len(inNode.List) > 1: for _, t := range inNode.List { ww.Parser.parseExpr(t) } } } else { for _, v := range inNode.List { ww.Parser.parseExpr(v) } } } for _, stmt := range inNode.Body { ast.Walk(ww, stmt) } w = nil case *ast.CommClause: ww.parseStmt(inNode.Comm) for _, stmt := range inNode.Body { ast.Walk(ww, stmt) } w = nil case *ast.FuncLit: meth := 
st.MakeFunction("#", ww.Current, lv.Parser.parseTypeSymbol(inNode.Type)) meth.Locals = st.NewSymbolTable(ww.Parser.Package) meth.Locals.AddOpenedScope(lv.Current) if meth.FunctionType.(*st.FunctionTypeSymbol).Parameters != nil { meth.Locals.AddOpenedScope(meth.FunctionType.(*st.FunctionTypeSymbol).Parameters) } if meth.FunctionType.(*st.FunctionTypeSymbol).Results != nil { meth.Locals.AddOpenedScope(meth.FunctionType.(*st.FunctionTypeSymbol).Results) } w = &innerScopeVisitor{meth, meth.Locals, lv.Parser, nil, lv.LabelsData} } return w }