// replaceStmt returns the (possibly many) statements that should replace // stmt. Generally a stmt is untouched or removed, but in some cases a // single stmt can result in multiple statements. This is usually only the case // when removing a block that was not taken, but pulling out function calls // that were part of the initialization of the block. func (v *trimVisitor) replaceStmt(stmt ast.Stmt) []ast.Stmt { switch stmt := stmt.(type) { case nil: return nil default: // Keep original return []ast.Stmt{stmt} case *ast.RangeStmt: if v.visited(stmt.Body) { return []ast.Stmt{stmt} } call := v.findCall(stmt.X) if call != nil { return []ast.Stmt{&ast.ExprStmt{call}} } return nil case *ast.ForStmt: if v.visited(stmt.Body) { return []ast.Stmt{stmt} } nodes := []*ast.CallExpr{ v.findCall(stmt.Init), v.findCall(stmt.Cond), v.findCall(stmt.Post), } var result []ast.Stmt for _, call := range nodes { if call != nil { result = append(result, &ast.ExprStmt{call}) } } return result case *ast.IfStmt: vIf := v.visited(stmt.Body) vElse := v.visited(stmt.Else) if !vIf { var result []ast.Stmt // If we didn't reach the body, pull out any calls from // init and cond. nodes := []*ast.CallExpr{ v.findCall(stmt.Init), v.findCall(stmt.Cond), } for _, call := range nodes { if call != nil { result = append(result, &ast.ExprStmt{call}) } } if vElse { // We reached the else; add it if block, ok := stmt.Else.(*ast.BlockStmt); ok { // For a block statement, add the statements individually // so we don't end up with an unnecessary block for _, stmt := range block.List { result = append(result, v.replaceStmt(stmt)...) } } else { result = append(result, v.replaceStmt(stmt.Else)...) } } return result } else { // We did take the if body if !vElse { // But not the else: remove it stmt.Else = nil } return []ast.Stmt{stmt} } case *ast.SelectStmt: var list []ast.Stmt for _, stmt := range stmt.Body.List { if v.visited(stmt) { list = append(list, stmt) } } stmt.Body.List = list return []ast.Stmt{stmt} case *ast.SwitchStmt: var list []ast.Stmt for _, stmt := range stmt.Body.List { if v.visitedAndMatters(stmt) { list = append(list, stmt) } } // If we didn't visit any case clauses, don't add the select at all. if len(list) == 0 { return nil } else { stmt.Body.List = list return []ast.Stmt{stmt} } case *ast.TypeSwitchStmt: var list []ast.Stmt for _, stmt := range stmt.Body.List { if v.visitedAndMatters(stmt) { list = append(list, stmt) } } // If we didn't visit any case clauses, don't add the select at all. if len(list) == 0 { return nil } else { stmt.Body.List = list return []ast.Stmt{stmt} } } }
func (v *StmtVisitor) VisitStmt(s ast.Stmt) { var statements *[]ast.Stmt switch s := s.(type) { case *ast.BlockStmt: statements = &s.List case *ast.CaseClause: statements = &s.Body case *ast.CommClause: statements = &s.Body case *ast.ForStmt: if s.Init != nil { v.VisitStmt(s.Init) } if s.Post != nil { v.VisitStmt(s.Post) } v.VisitStmt(s.Body) case *ast.IfStmt: if s.Init != nil { v.VisitStmt(s.Init) } v.VisitStmt(s.Body) if s.Else != nil { // Code copied from go.tools/cmd/cover, to deal with "if x {} else if y {}" const backupToElse = token.Pos(len("else ")) // The AST doesn't remember the else location. We can make an accurate guess. switch stmt := s.Else.(type) { case *ast.IfStmt: block := &ast.BlockStmt{ Lbrace: stmt.If - backupToElse, // So the covered part looks like it starts at the "else". List: []ast.Stmt{stmt}, Rbrace: stmt.End(), } s.Else = block case *ast.BlockStmt: stmt.Lbrace -= backupToElse // So the block looks like it starts at the "else". default: panic("unexpected node type in if") } v.VisitStmt(s.Else) } case *ast.LabeledStmt: v.VisitStmt(s.Stmt) case *ast.RangeStmt: v.VisitStmt(s.Body) case *ast.SelectStmt: v.VisitStmt(s.Body) case *ast.SwitchStmt: if s.Init != nil { v.VisitStmt(s.Init) } v.VisitStmt(s.Body) case *ast.TypeSwitchStmt: if s.Init != nil { v.VisitStmt(s.Init) } v.VisitStmt(s.Assign) v.VisitStmt(s.Body) } if statements == nil { return } for i := 0; i < len(*statements); i++ { s := (*statements)[i] switch s.(type) { case *ast.CaseClause, *ast.CommClause, *ast.BlockStmt: break default: start, end := v.fset.Position(s.Pos()), v.fset.Position(s.End()) se := &StmtExtent{ startOffset: start.Offset, startLine: start.Line, startCol: start.Column, endOffset: end.Offset, endLine: end.Line, endCol: end.Column, } v.function.stmts = append(v.function.stmts, se) } v.VisitStmt(s) } }