func (g *raftProxyGen) genProxyStruct(s *descriptor.ServiceDescriptorProto) { g.gen.P("type " + serviceTypeName(s) + " struct {") g.gen.P("\tlocal " + s.GetName() + "Server") g.gen.P("\tconnSelector raftselector.ConnProvider") g.gen.P("\tctxMods []func(context.Context)(context.Context, error)") g.gen.P("}") }
func (g *raftProxyGen) genProxyConstructor(s *descriptor.ServiceDescriptorProto) { g.gen.P("func NewRaftProxy" + s.GetName() + "Server(local " + s.GetName() + "Server, connSelector raftselector.ConnProvider, ctxMod func(context.Context)(context.Context, error)) " + s.GetName() + "Server {") g.gen.P(`redirectChecker := func(ctx context.Context)(context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { return ctx, grpc.Errorf(codes.InvalidArgument, "remote addr is not found in context") } addr := s.ServerTransport().RemoteAddr().String() md, ok := metadata.FromContext(ctx) if ok && len(md["redirect"]) != 0 { return ctx, grpc.Errorf(codes.ResourceExhausted, "more than one redirect to leader from: %s", md["redirect"]) } if !ok { md = metadata.New(map[string]string{}) } md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } mods := []func(context.Context)(context.Context, error){redirectChecker} mods = append(mods, ctxMod) `) g.gen.P("return &" + serviceTypeName(s) + `{ local: local, connSelector: connSelector, ctxMods: mods, }`) g.gen.P("}") }
func (g *authenticatedWrapperGen) genServerStreamingMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) { g.gen.P(sigPrefix(s, m) + "r *" + getInputTypeName(m) + ", stream " + s.GetName() + "_" + m.GetName() + "Server) error {") authIntf, err := proto.GetExtension(m.Options, plugin.E_TlsAuthorization) if err != nil { g.gen.P(` panic("no authorization information in protobuf")`) g.gen.P(`}`) return } auth := authIntf.(*plugin.TLSAuthorization) if auth.Insecure != nil && *auth.Insecure { if len(auth.Roles) != 0 { panic("Roles and Insecure cannot both be specified") } g.gen.P(` return p.local.` + m.GetName() + `(r, stream)`) g.gen.P(`}`) return } g.gen.P(` if err := p.authorize(stream.Context(),` + genRoles(auth) + `); err != nil { return err } return p.local.` + m.GetName() + `(r, stream)`) g.gen.P("}") }
func (g *raftProxyGen) genSimpleMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) { g.gen.P(sigPrefix(s, m) + "ctx context.Context, r *" + getInputTypeName(m) + ") (*" + getOutputTypeName(m) + ", error) {") g.gen.P(` conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { return p.local.` + m.GetName() + `(ctx, r) } return nil, err } modCtx, err := p.runCtxMods(ctx) if err != nil { return nil, err }`) g.gen.P(` resp, err := New` + s.GetName() + `Client(conn).` + m.GetName() + `(modCtx, r) if err != nil { if !strings.Contains(err.Error(), "is closing") && !strings.Contains(err.Error(), "the connection is unavailable") && !strings.Contains(err.Error(), "connection error") { return resp, err } conn, err := p.pollNewLeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { return p.local.` + m.GetName() + `(ctx, r) } return nil, err } return New` + s.GetName() + `Client(conn).` + m.GetName() + `(modCtx, r) }`) g.gen.P("return resp, err") g.gen.P("}") }
func (g *authenticatedWrapperGen) genAuthenticatedConstructor(s *descriptor.ServiceDescriptorProto) { g.gen.P("func NewAuthenticatedWrapper" + s.GetName() + "Server(local " + s.GetName() + "Server, authorize func(context.Context, []string) error)" + s.GetName() + "Server {") g.gen.P("return &" + serviceTypeName(s) + `{ local: local, authorize: authorize, }`) g.gen.P("}") }
func (g *raftProxyGen) genClientStreamingMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) { streamType := s.GetName() + "_" + m.GetName() + "Server" // Generate stream wrapper that returns a modified context g.genStreamWrapper(streamType) g.gen.P(sigPrefix(s, m) + "stream " + streamType + `) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { ctx, err = p.runCtxMods(ctx, p.localCtxMods) if err != nil { return err } streamWrapper := ` + streamType + `Wrapper{ ` + streamType + `: stream, ctx: ctx, } return p.local.` + m.GetName() + `(streamWrapper) } return err } ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err }`) g.gen.P("clientStream, err := New" + s.GetName() + "Client(conn)." + m.GetName() + "(ctx)") g.gen.P(` if err != nil { return err }`) g.gen.P(` for { msg, err := stream.Recv() if err == io.EOF { break } if err != nil { return err } if err := clientStream.Send(msg); err != nil { return err } } reply, err := clientStream.CloseAndRecv() if err != nil { return err } return stream.SendAndClose(reply)`) g.gen.P("}") }
func (g *raftProxyGen) genClientServerStreamingMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) { g.gen.P(sigPrefix(s, m) + "stream " + s.GetName() + "_" + m.GetName() + "Server) error {") g.gen.P(` ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { return p.local.` + m.GetName() + `(stream) } return err } ctx, err = p.runCtxMods(ctx) if err != nil { return err }`) g.gen.P("clientStream, err := New" + s.GetName() + "Client(conn)." + m.GetName() + "(ctx)") g.gen.P(` if err != nil { return err }`) g.gen.P(`errc := make(chan error, 1) go func() { msg, err := stream.Recv() if err == io.EOF { close(errc) return } if err != nil { errc <- err return } if err := clientStream.Send(msg); err != nil { errc <- err return } }()`) g.gen.P(` for { msg, err := clientStream.Recv() if err == io.EOF { break } if err != nil { return err } if err := stream.Send(msg); err != nil { return err } } clientStream.CloseSend() return <-errc`) g.gen.P("}") }
func (g *raftProxyGen) genServerStreamingMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) { streamType := s.GetName() + "_" + m.GetName() + "Server" g.genStreamWrapper(streamType) g.gen.P(sigPrefix(s, m) + "r *" + getInputTypeName(m) + ", stream " + streamType + `) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { ctx, err = p.runCtxMods(ctx, p.localCtxMods) if err != nil { return err } streamWrapper := ` + streamType + `Wrapper{ ` + streamType + `: stream, ctx: ctx, } return p.local.` + m.GetName() + `(r, streamWrapper) } return err } ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err }`) g.gen.P("clientStream, err := New" + s.GetName() + "Client(conn)." + m.GetName() + "(ctx, r)") g.gen.P(` if err != nil { return err }`) g.gen.P(` for { msg, err := clientStream.Recv() if err == io.EOF { break } if err != nil { return err } if err := stream.Send(msg); err != nil { return err } } return nil`) g.gen.P("}") }
func (g *raftProxyGen) genSimpleMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) { g.gen.P(sigPrefix(s, m) + "ctx context.Context, r *" + getInputTypeName(m) + ") (*" + getOutputTypeName(m) + ", error) {") g.gen.P(` conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { return p.local.` + m.GetName() + `(ctx, r) } return nil, err } ctx, err = p.runCtxMods(ctx) if err != nil { return nil, err }`) g.gen.P("return New" + s.GetName() + "Client(conn)." + m.GetName() + "(ctx, r)") g.gen.P("}") }
// generateService generates all the code for the named service. func (g *grpc) generateService(file *generator.FileDescriptor, service *pb.ServiceDescriptorProto, index int) { path := fmt.Sprintf("6,%d", index) // 6 means service. origServName := service.GetName() fullServName := origServName if pkg := file.GetPackage(); pkg != "" { fullServName = pkg + "." + fullServName } servName := generator.CamelCase(origServName) g.P() g.P("// Client API for ", servName, " service") g.P() // Client interface. g.P("type ", servName, "Client interface {") for i, method := range service.Method { g.gen.PrintComments(fmt.Sprintf("%s,2,%d", path, i)) // 2 means method in a service. g.P(g.generateClientSignature(servName, method)) } g.P("}") g.P() // Client structure. g.P("type ", unexport(servName), "Client struct {") g.P("cc *", grpcPkg, ".ClientConn") g.P("}") g.P() // NewClient factory. g.P("func New", servName, "Client (cc *", grpcPkg, ".ClientConn) ", servName, "Client {") g.P("return &", unexport(servName), "Client{cc}") g.P("}") g.P() var methodIndex, streamIndex int serviceDescVar := "_" + servName + "_serviceDesc" // Client method implementations. for _, method := range service.Method { var descExpr string if !method.GetServerStreaming() && !method.GetClientStreaming() { // Unary RPC method descExpr = fmt.Sprintf("&%s.Methods[%d]", serviceDescVar, methodIndex) methodIndex++ } else { // Streaming RPC method descExpr = fmt.Sprintf("&%s.Streams[%d]", serviceDescVar, streamIndex) streamIndex++ } g.generateClientMethod(servName, fullServName, serviceDescVar, method, descExpr) } g.P("// Server API for ", servName, " service") g.P() // Server interface. serverType := servName + "Server" g.P("type ", serverType, " interface {") for i, method := range service.Method { g.gen.PrintComments(fmt.Sprintf("%s,2,%d", path, i)) // 2 means method in a service. g.P(g.generateServerSignature(servName, method)) } g.P("}") g.P() // Server registration. 
g.P("func Register", servName, "Server(s *", grpcPkg, ".Server, srv ", serverType, ") {") g.P("s.RegisterService(&", serviceDescVar, `, srv)`) g.P("}") g.P() // Server handler implementations. var handlerNames []string for _, method := range service.Method { hname := g.generateServerMethod(servName, fullServName, method) handlerNames = append(handlerNames, hname) } // Service descriptor. g.P("var ", serviceDescVar, " = ", grpcPkg, ".ServiceDesc {") g.P("ServiceName: ", strconv.Quote(fullServName), ",") g.P("HandlerType: (*", serverType, ")(nil),") g.P("Methods: []", grpcPkg, ".MethodDesc{") for i, method := range service.Method { if method.GetServerStreaming() || method.GetClientStreaming() { continue } g.P("{") g.P("MethodName: ", strconv.Quote(method.GetName()), ",") g.P("Handler: ", handlerNames[i], ",") g.P("},") } g.P("},") g.P("Streams: []", grpcPkg, ".StreamDesc{") for i, method := range service.Method { if !method.GetServerStreaming() && !method.GetClientStreaming() { continue } g.P("{") g.P("StreamName: ", strconv.Quote(method.GetName()), ",") g.P("Handler: ", handlerNames[i], ",") if method.GetServerStreaming() { g.P("ServerStreams: true,") } if method.GetClientStreaming() { g.P("ClientStreams: true,") } g.P("},") } g.P("},") g.P("Metadata: \"", file.GetName(), "\",") g.P("}") g.P() }
// serviceTypeName returns the name of the generated raft proxy struct for
// service s: "raftProxy" + <ServiceName> + "Server".
func serviceTypeName(s *descriptor.ServiceDescriptorProto) string {
	const prefix, suffix = "raftProxy", "Server"
	return prefix + s.GetName() + suffix
}
// serviceTypeName returns the name of the generated authenticated wrapper
// struct for service s: "authenticatedWrapper" + <ServiceName> + "Server".
func serviceTypeName(s *descriptor.ServiceDescriptorProto) string {
	const prefix, suffix = "authenticatedWrapper", "Server"
	return prefix + s.GetName() + suffix
}
func (g *authenticatedWrapperGen) genAuthenticatedStruct(s *descriptor.ServiceDescriptorProto) { g.gen.P("type " + serviceTypeName(s) + " struct {") g.gen.P(" local " + s.GetName() + "Server") g.gen.P(" authorize func(context.Context, []string) error") g.gen.P("}") }