// fixAndAndAssign rewrites if(x && (y = z) ...) ... to if(x) { y = z; if(...) ... } func fixAndAndAssign(stmt *cc.Stmt) { changed := false clauses := splitExpr(stmt.Expr, cc.AndAnd) for i := len(clauses) - 1; i > 0; i-- { before, _ := extractSideEffects(clauses[i], sideNoAfter) if len(before) == 0 { continue } changed = true stmt.Body = &cc.Stmt{ Op: BlockNoBrace, Block: append(before, &cc.Stmt{ Op: cc.If, Expr: joinExpr(clauses[i:], cc.AndAnd), Body: stmt.Body, }), } clauses = clauses[:i] } if changed { stmt.Expr = joinExpr(clauses, cc.AndAnd) } }
func fixMemset(prog *cc.Prog, fn *cc.Decl, stmt *cc.Stmt) { x := stmt.Expr if len(x.List) != 3 || x.List[1].Op != cc.Number || x.List[1].Text != "0" { // fprintf(x.Span, "unsupported %v - nonzero", x) return } if x.List[2].Op == cc.SizeofExpr || x.List[2].Op == cc.SizeofType { obj, objType := objIndir(fn, x.List[0]) if !matchSize(fn, obj, objType, x.List[2]) { // fprintf(x.Span, "unsupported %v - wrong size", x) return } x.Op = cc.Eq x.Left = obj x.Right = zeroFor(objType) x.List = nil return } siz := x.List[2] var count *cc.Expr var objType *cc.Type if siz.Op == cc.Mul { count = siz.Left siz = siz.Right if siz.Op != cc.SizeofExpr && siz.Op != cc.SizeofType { // fprintf(x.Span, "unsupported %v - wrong array size", x) return } switch siz.Op { case cc.SizeofExpr: p := unparen(siz.Left) if p.Op != cc.Indir && p.Op != cc.Index || !sameType(p.Left.XType, x.List[0].XType) { // fprintf(x.Span, "unsupported %v - wrong size", x) } objType = fixGoTypesExpr(fn, x.List[0], nil) case cc.SizeofType: objType = fixGoTypesExpr(fn, x.List[0], nil) if !sameType(siz.Type, objType.Base) { // fprintf(x.Span, "unsupported %v - wrong size", x) } } } else { count = siz objType = fixGoTypesExpr(fn, x.List[0], nil) if !objType.Base.Is(Byte) && !objType.Base.Is(Uint8) { // fprintf(x.Span, "unsupported %v - wrong size form for non-byte type", x) return } } if objType == nil { fprintf(x.Span, "unsupported %v - lost type", x) return } // Found it. Replace with zeroing for loop. stmt.Op = cc.For stmt.Pre = &cc.Expr{ Op: cc.Eq, Left: &cc.Expr{ Op: cc.Name, Text: "i", XType: intType, }, Right: &cc.Expr{ Op: cc.Number, Text: "0", XType: intType, }, XType: boolType, } stmt.Expr = &cc.Expr{ Op: cc.Lt, Left: &cc.Expr{ Op: cc.Name, Text: "i", XType: intType, }, Right: count, XType: boolType, } stmt.Post = &cc.Expr{ Op: cc.PostInc, Left: &cc.Expr{ Op: cc.Name, Text: "i", XType: intType, }, XType: intType, } stmt.Body = &cc.Stmt{ Op: cc.Block, Block: []*cc.Stmt{ { Op: cc.StmtExpr, Expr: &cc.Expr{ Op: cc.Eq, Left: &cc.Expr{ Op: cc.Index, Left: x.List[0], Right: &cc.Expr{Op: cc.Name, Text: "i"}, }, Right: zeroFor(objType.Base), }, }, }, } return }
func rewriteStmt(stmt *cc.Stmt) { // TODO: Double-check stmt.Labels switch stmt.Op { case cc.ARGBEGIN: panic(fmt.Sprintf("unexpected ARGBEGIN")) case cc.Do: // Rewrite do { ... } while(x) // to for(;;) { ... if(!x) break } // Since rewriteStmt is called in a preorder traversal, // the recursion into the children will clean up x // in the if condition as needed. stmt.Op = cc.For x := stmt.Expr stmt.Expr = nil stmt.Body = forceBlock(stmt.Body) stmt.Body.Block = append(stmt.Body.Block, &cc.Stmt{ Op: cc.If, Expr: &cc.Expr{Op: cc.Not, Left: x}, Body: &cc.Stmt{Op: cc.Break}, }) case cc.While: stmt.Op = cc.For fallthrough case cc.For: before1, _ := extractSideEffects(stmt.Pre, sideStmt|sideNoAfter) before2, _ := extractSideEffects(stmt.Expr, sideNoAfter) if len(before2) > 0 { x := stmt.Expr stmt.Expr = nil stmt.Body = forceBlock(stmt.Body) top := &cc.Stmt{ Op: cc.If, Expr: &cc.Expr{Op: cc.Not, Left: x}, Body: &cc.Stmt{Op: cc.Break}, } stmt.Body.Block = append(append(before2, top), stmt.Body.Block...) } if len(before1) > 0 { old := copyStmt(stmt) stmt.Pre = nil stmt.Expr = nil stmt.Post = nil stmt.Body = nil stmt.Op = BlockNoBrace stmt.Block = append(before1, old) } before, after := extractSideEffects(stmt.Post, sideStmt) if len(before)+len(after) > 0 { all := append(append(before, &cc.Stmt{Op: cc.StmtExpr, Expr: stmt.Post}), after...) stmt.Post = &cc.Expr{Op: ExprBlock, Block: all} } case cc.If, cc.Return: if stmt.Op == cc.If && stmt.Else == nil { fixAndAndAssign(stmt) } before, _ := extractSideEffects(stmt.Expr, sideNoAfter) if len(before) > 0 { old := copyStmt(stmt) stmt.Expr = nil stmt.Body = nil stmt.Else = nil stmt.Op = BlockNoBrace stmt.Block = append(before, old) } case cc.StmtExpr: before, after := extractSideEffects(stmt.Expr, sideStmt) if len(before)+len(after) > 0 { old := copyStmt(stmt) stmt.Expr = nil stmt.Op = BlockNoBrace stmt.Block = append(append(before, old), after...) } case cc.Goto: // TODO: Figure out where the goto goes and maybe rewrite // to labeled break/continue. // Otherwise move code or something. case cc.Switch: // TODO: Change default fallthrough to default break. before, _ := extractSideEffects(stmt.Expr, sideNoAfter) if len(before) > 0 { old := copyStmt(stmt) stmt.Expr = nil stmt.Body = nil stmt.Else = nil stmt.Op = BlockNoBrace stmt.Block = append(before, old) break // recursion will rewrite new inner switch } rewriteSwitch(stmt) } }