// rewriteOptionalMethods makes specific mutations to marshaller methods that belong to types identified // as being "optional" (they may be nil on the wire). This allows protobuf to serialize a map or slice and // properly discriminate between empty and nil (which is not possible in protobuf). // TODO: move into upstream gogo-protobuf once https://github.com/gogo/protobuf/issues/181 // has agreement func rewriteOptionalMethods(decl ast.Decl, isOptional OptionalFunc) { switch t := decl.(type) { case *ast.FuncDecl: ident, ptr, ok := receiver(t) if !ok { return } // correct initialization of the form `m.Field = &OptionalType{}` to // `m.Field = OptionalType{}` if t.Name.Name == "Unmarshal" { ast.Walk(optionalAssignmentVisitor{fn: isOptional}, t.Body) } if !isOptional(ident.Name) { return } switch t.Name.Name { case "Unmarshal": ast.Walk(&optionalItemsVisitor{}, t.Body) case "MarshalTo", "Size": ast.Walk(&optionalItemsVisitor{}, t.Body) fallthrough case "Marshal": // if the method has a pointer receiver, set it back to a normal receiver if ptr { t.Recv.List[0].Type = ident } } } }
// Visit walks the provided node, looking for specific patterns to transform that match // the effective outcome of turning struct{ map[x]y || []x } into map[x]y or []x. func (v *optionalItemsVisitor) Visit(n ast.Node) ast.Visitor { switch t := n.(type) { case *ast.RangeStmt: if isFieldSelector(t.X, "m", "Items") { t.X = &ast.Ident{Name: "m"} } case *ast.AssignStmt: if len(t.Lhs) == 1 && len(t.Rhs) == 1 { switch lhs := t.Lhs[0].(type) { case *ast.IndexExpr: if isFieldSelector(lhs.X, "m", "Items") { lhs.X = &ast.StarExpr{X: &ast.Ident{Name: "m"}} } default: if isFieldSelector(t.Lhs[0], "m", "Items") { t.Lhs[0] = &ast.StarExpr{X: &ast.Ident{Name: "m"}} } } switch rhs := t.Rhs[0].(type) { case *ast.CallExpr: if ident, ok := rhs.Fun.(*ast.Ident); ok && ident.Name == "append" { ast.Walk(v, rhs) if len(rhs.Args) > 0 { switch arg := rhs.Args[0].(type) { case *ast.Ident: if arg.Name == "m" { rhs.Args[0] = &ast.StarExpr{X: &ast.Ident{Name: "m"}} } } } return nil } } } case *ast.IfStmt: if b, ok := t.Cond.(*ast.BinaryExpr); ok && b.Op == token.EQL { if isFieldSelector(b.X, "m", "Items") && isIdent(b.Y, "nil") { b.X = &ast.StarExpr{X: &ast.Ident{Name: "m"}} } } case *ast.IndexExpr: if isFieldSelector(t.X, "m", "Items") { t.X = &ast.Ident{Name: "m"} return nil } case *ast.CallExpr: changed := false for i := range t.Args { if isFieldSelector(t.Args[i], "m", "Items") { t.Args[i] = &ast.Ident{Name: "m"} changed = true } } if changed { return nil } } return v }
// idents is an iterator that returns all idents in f via the result channel. func idents(f *ast.File) <-chan *ast.Ident { v := make(visitor) go func() { ast.Walk(v, f) close(v) }() return v }
// Visit walks the provided node, looking for specific patterns to transform that match // the effective outcome of turning struct{ map[x]y || []x } into map[x]y or []x. func (v *optionalItemsVisitor) Visit(n ast.Node) ast.Visitor { switch t := n.(type) { case *ast.RangeStmt: if isFieldSelector(t.X, "m", "Items") { t.X = &ast.Ident{Name: "m"} } case *ast.AssignStmt: if len(t.Lhs) == 1 && len(t.Rhs) == 1 { switch lhs := t.Lhs[0].(type) { case *ast.IndexExpr: if isFieldSelector(lhs.X, "m", "Items") { lhs.X = &ast.StarExpr{X: &ast.Ident{Name: "m"}} } default: if isFieldSelector(t.Lhs[0], "m", "Items") { t.Lhs[0] = &ast.StarExpr{X: &ast.Ident{Name: "m"}} } } switch rhs := t.Rhs[0].(type) { case *ast.CallExpr: if ident, ok := rhs.Fun.(*ast.Ident); ok && ident.Name == "append" { ast.Walk(v, rhs) if len(rhs.Args) > 0 { switch arg := rhs.Args[0].(type) { case *ast.Ident: if arg.Name == "m" { rhs.Args[0] = &ast.StarExpr{X: &ast.Ident{Name: "m"}} } } } return nil } } } case *ast.IfStmt: switch cond := t.Cond.(type) { case *ast.BinaryExpr: if cond.Op == token.EQL { if isFieldSelector(cond.X, "m", "Items") && isIdent(cond.Y, "nil") { cond.X = &ast.StarExpr{X: &ast.Ident{Name: "m"}} } } } if t.Init != nil { // Find form: // if err := m[len(m.Items)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { // return err // } switch s := t.Init.(type) { case *ast.AssignStmt: if call, ok := s.Rhs[0].(*ast.CallExpr); ok { if sel, ok := call.Fun.(*ast.SelectorExpr); ok { if x, ok := sel.X.(*ast.IndexExpr); ok { // m[] -> (*m)[] if sel2, ok := x.X.(*ast.SelectorExpr); ok { if ident, ok := sel2.X.(*ast.Ident); ok && ident.Name == "m" { x.X = &ast.StarExpr{X: &ast.Ident{Name: "m"}} } } // len(m.Items) -> len(*m) if bin, ok := x.Index.(*ast.BinaryExpr); ok { if call2, ok := bin.X.(*ast.CallExpr); ok && len(call2.Args) == 1 { if isFieldSelector(call2.Args[0], "m", "Items") { call2.Args[0] = &ast.StarExpr{X: &ast.Ident{Name: "m"}} } } } } } } } } case *ast.IndexExpr: if isFieldSelector(t.X, "m", "Items") { t.X = &ast.Ident{Name: "m"} return nil } case *ast.CallExpr: changed := false for i := range t.Args { if isFieldSelector(t.Args[i], "m", "Items") { t.Args[i] = &ast.Ident{Name: "m"} changed = true } } if changed { return nil } } return v }