Example #1
0
// rewriteOptionalMethods mutates generated marshalling methods so that types
// reported by isOptional behave like the bare map or slice they wrap. Wrapping
// a collection in an "optional" type is what lets a nil value be told apart
// from an empty one on the wire, which plain protobuf cannot express.
//
// Only *ast.FuncDecl values for which receiver reports a receiver are touched.
// For those:
//   - every Unmarshal has its `m.Field = &OptionalType{}` initializers turned
//     into `m.Field = OptionalType{}` (optionalAssignmentVisitor), whether or
//     not the receiver itself is optional;
//   - when the receiver type is optional, Unmarshal, MarshalTo and Size have
//     their m.Items accesses collapsed onto the receiver (optionalItemsVisitor),
//     and Marshal, MarshalTo and Size are switched from a pointer receiver to a
//     value receiver.
//
// TODO: move into upstream gogo-protobuf once
// https://github.com/gogo/protobuf/issues/181 has agreement.
func rewriteOptionalMethods(decl ast.Decl, isOptional OptionalFunc) {
	fn, isFunc := decl.(*ast.FuncDecl)
	if !isFunc {
		return
	}
	recvType, byPointer, hasRecv := receiver(fn)
	if !hasRecv {
		return
	}
	method := fn.Name.Name

	// Field initializers are fixed in every Unmarshal, before the
	// optionality of the receiver is even considered.
	if method == "Unmarshal" {
		ast.Walk(optionalAssignmentVisitor{fn: isOptional}, fn.Body)
	}
	if !isOptional(recvType.Name) {
		return
	}

	collapseItems := method == "Unmarshal" || method == "MarshalTo" || method == "Size"
	valueReceiver := method == "Marshal" || method == "MarshalTo" || method == "Size"

	if collapseItems {
		ast.Walk(&optionalItemsVisitor{}, fn.Body)
	}
	// Unmarshal is intentionally excluded here: it keeps its pointer
	// receiver, and the items visitor rewrites its writes to go through *m.
	if valueReceiver && byPointer {
		fn.Recv.List[0].Type = recvType
	}
}
Example #2
0
// Visit walks the provided node, looking for specific patterns to transform that match
// the effective outcome of turning struct{ map[x]y || []x } into map[x]y or []x.
//
// Rewrites applied (m is the receiver, Items the wrapped field):
//
//	range m.Items                            -> range m
//	m.Items[k] = v                           -> (*m)[k] = v
//	m.Items = v                              -> *m = v
//	m.Items = append(m.Items, x)             -> *m = append(*m, x)
//	if m.Items == nil                        -> if *m == nil
//	if err := m.Items[len(m.Items)-1].F(...) -> if err := (*m)[len(*m)-1].F(...)
//	m.Items[k] (any other index)             -> m[k]
//	f(..., m.Items, ...)                     -> f(..., m, ...)
//
// Returning nil stops ast.Walk from descending into n's children; returning v
// continues the walk.
func (v *optionalItemsVisitor) Visit(n ast.Node) ast.Visitor {
	switch t := n.(type) {
	case *ast.RangeStmt:
		if isFieldSelector(t.X, "m", "Items") {
			t.X = &ast.Ident{Name: "m"}
		}
	case *ast.AssignStmt:
		if len(t.Lhs) == 1 && len(t.Rhs) == 1 {
			switch lhs := t.Lhs[0].(type) {
			case *ast.IndexExpr:
				if isFieldSelector(lhs.X, "m", "Items") {
					lhs.X = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
				}
			default:
				if isFieldSelector(t.Lhs[0], "m", "Items") {
					t.Lhs[0] = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
				}
			}
			switch rhs := t.Rhs[0].(type) {
			case *ast.CallExpr:
				if ident, ok := rhs.Fun.(*ast.Ident); ok && ident.Name == "append" {
					// Rewrite m.Items -> m in the arguments first, then turn the
					// appended-to slice into *m to match the *m on the left.
					ast.Walk(v, rhs)
					if len(rhs.Args) > 0 {
						switch arg := rhs.Args[0].(type) {
						case *ast.Ident:
							if arg.Name == "m" {
								rhs.Args[0] = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
							}
						}
					}
					return nil
				}
			}
		}
	case *ast.IfStmt:
		if b, ok := t.Cond.(*ast.BinaryExpr); ok && b.Op == token.EQL {
			if isFieldSelector(b.X, "m", "Items") && isIdent(b.Y, "nil") {
				b.X = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
			}
		}
		// Unmarshal of a repeated message appends a zero value and then
		// decodes into it from an if-initializer:
		//   if err := m.Items[len(m.Items)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil {
		// Unmarshal keeps its pointer receiver, so both the indexed slice and
		// the len() argument must dereference m. This has to happen here,
		// before the children are walked: the generic *ast.IndexExpr rule below
		// would otherwise yield m[...] (a pointer to a slice is not indexable)
		// and stop descending, leaving len(m.Items) behind.
		if s, ok := t.Init.(*ast.AssignStmt); ok && len(s.Rhs) == 1 {
			if call, ok := s.Rhs[0].(*ast.CallExpr); ok {
				if sel, ok := call.Fun.(*ast.SelectorExpr); ok {
					if idx, ok := sel.X.(*ast.IndexExpr); ok && isFieldSelector(idx.X, "m", "Items") {
						// m.Items[...] -> (*m)[...]
						idx.X = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
						// len(m.Items) - 1 -> len(*m) - 1
						if bin, ok := idx.Index.(*ast.BinaryExpr); ok {
							if lenCall, ok := bin.X.(*ast.CallExpr); ok &&
								isIdent(lenCall.Fun, "len") &&
								len(lenCall.Args) == 1 &&
								isFieldSelector(lenCall.Args[0], "m", "Items") {
								lenCall.Args[0] = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
							}
						}
					}
				}
			}
		}
	case *ast.IndexExpr:
		if isFieldSelector(t.X, "m", "Items") {
			t.X = &ast.Ident{Name: "m"}
			return nil
		}
	case *ast.CallExpr:
		changed := false
		for i := range t.Args {
			if isFieldSelector(t.Args[i], "m", "Items") {
				t.Args[i] = &ast.Ident{Name: "m"}
				changed = true
			}
		}
		if changed {
			return nil
		}
	}
	return v
}
Example #3
0
// idents streams the identifiers that ast.Walk reaches in f, as reported by
// the visitor type, over the returned channel, and closes the channel once the
// walk completes.
//
// The walk runs on its own goroutine and the channel is unbuffered, so each
// send (presumably made from visitor.Visit) waits for a receiver. Callers must
// drain the channel; abandoning it early leaves that goroutine blocked forever.
func idents(f *ast.File) <-chan *ast.Ident {
	results := make(visitor)
	walk := func() {
		ast.Walk(results, f)
		close(results)
	}
	go walk()
	return results
}
Example #4
0
// Visit is the ast.Visitor hook that rewrites accesses to the wrapped field of
// an optional type so that generated code works on the named map/slice type
// directly, mimicking the collapse of struct{ map[x]y || []x } into map[x]y or
// []x. Returning nil stops ast.Walk from descending into n's children;
// returning v continues the walk.
func (v *optionalItemsVisitor) Visit(n ast.Node) ast.Visitor {
	// Fresh replacement nodes: the receiver itself, and the receiver dereferenced.
	bare := func() ast.Expr { return &ast.Ident{Name: "m"} }
	deref := func() ast.Expr { return &ast.StarExpr{X: &ast.Ident{Name: "m"}} }

	switch node := n.(type) {
	case *ast.RangeStmt:
		// range m.Items -> range m
		if isFieldSelector(node.X, "m", "Items") {
			node.X = bare()
		}

	case *ast.AssignStmt:
		if len(node.Lhs) != 1 || len(node.Rhs) != 1 {
			break
		}
		// m.Items[k] = ... -> (*m)[k] = ...   and   m.Items = ... -> *m = ...
		target := node.Lhs[0]
		if index, isIndex := target.(*ast.IndexExpr); isIndex {
			if isFieldSelector(index.X, "m", "Items") {
				index.X = deref()
			}
		} else if isFieldSelector(target, "m", "Items") {
			node.Lhs[0] = deref()
		}

		call, isCall := node.Rhs[0].(*ast.CallExpr)
		if !isCall {
			break
		}
		if fn, isName := call.Fun.(*ast.Ident); !isName || fn.Name != "append" {
			break
		}
		// append(m.Items, ...) -> append(*m, ...): walking the call turns the
		// argument into m, which is then dereferenced here.
		ast.Walk(v, call)
		if len(call.Args) > 0 {
			if first, isName := call.Args[0].(*ast.Ident); isName && first.Name == "m" {
				call.Args[0] = deref()
			}
		}
		return nil

	case *ast.IfStmt:
		// if m.Items == nil -> if *m == nil
		if cond, isBin := node.Cond.(*ast.BinaryExpr); isBin && cond.Op == token.EQL &&
			isFieldSelector(cond.X, "m", "Items") && isIdent(cond.Y, "nil") {
			cond.X = deref()
		}

		// Handle an initializer of the form
		//   if err := m.Items[len(m.Items)-1].Unmarshal(data[iNdEx:postIndex]); err != nil {
		// before its children are visited.
		setup, isAssign := node.Init.(*ast.AssignStmt)
		if !isAssign {
			break
		}
		call, isCall := setup.Rhs[0].(*ast.CallExpr)
		if !isCall {
			break
		}
		method, isSel := call.Fun.(*ast.SelectorExpr)
		if !isSel {
			break
		}
		elem, isIndex := method.X.(*ast.IndexExpr)
		if !isIndex {
			break
		}
		// m.<field>[...] -> (*m)[...], for any field selected on m.
		if field, isFieldSel := elem.X.(*ast.SelectorExpr); isFieldSel {
			if recv, isName := field.X.(*ast.Ident); isName && recv.Name == "m" {
				elem.X = deref()
			}
		}
		// <call>(m.Items) <op> ... -> <call>(*m) <op> ..., e.g. len(m.Items)-1.
		if bin, isBin := elem.Index.(*ast.BinaryExpr); isBin {
			if inner, isInner := bin.X.(*ast.CallExpr); isInner && len(inner.Args) == 1 &&
				isFieldSelector(inner.Args[0], "m", "Items") {
				inner.Args[0] = deref()
			}
		}

	case *ast.IndexExpr:
		// m.Items[k] -> m[k]
		if isFieldSelector(node.X, "m", "Items") {
			node.X = bare()
			return nil
		}

	case *ast.CallExpr:
		// f(..., m.Items, ...) -> f(..., m, ...)
		rewrote := false
		for i, arg := range node.Args {
			if !isFieldSelector(arg, "m", "Items") {
				continue
			}
			node.Args[i] = bare()
			rewrote = true
		}
		if rewrote {
			return nil
		}
	}
	return v
}