예제 #1
0
// genProxyStruct emits the type definition of the raft proxy server for
// service s. The generated struct has three fields:
//   - local: the <Service>Server implementation running on this node;
//   - connSelector: the raftselector.ConnProvider used to reach the leader;
//   - ctxMods: a list of context-modifier funcs.
func (g *raftProxyGen) genProxyStruct(s *descriptor.ServiceDescriptorProto) {
	fields := []string{
		"\tlocal " + s.GetName() + "Server",
		"\tconnSelector raftselector.ConnProvider",
		"\tctxMods []func(context.Context)(context.Context, error)",
	}

	g.gen.P("type " + serviceTypeName(s) + " struct {")
	for _, field := range fields {
		g.gen.P(field)
	}
	g.gen.P("}")
}
예제 #2
0
// genProxyConstructor emits NewRaftProxy<Service>Server, the constructor for
// the raft proxy server type of service s.
//
// The generated constructor always installs a redirectChecker context
// modifier first. That modifier:
//   - reads the peer address from the gRPC transport stream;
//   - rejects a request whose metadata already carries a "redirect" entry,
//     returning codes.ResourceExhausted;
//   - otherwise appends the peer address to the "redirect" metadata.
//
// The caller-supplied ctxMod is appended after it. A nil ctxMod is skipped,
// so the generated ctxMods slice never contains a nil func that would panic
// when invoked.
func (g *raftProxyGen) genProxyConstructor(s *descriptor.ServiceDescriptorProto) {
	g.gen.P("func NewRaftProxy" + s.GetName() + "Server(local " + s.GetName() + "Server, connSelector raftselector.ConnProvider, ctxMod func(context.Context)(context.Context, error)) " + s.GetName() + "Server {")
	g.gen.P(`redirectChecker := func(ctx context.Context)(context.Context, error) {
		s, ok := transport.StreamFromContext(ctx)
		if !ok {
			return ctx, grpc.Errorf(codes.InvalidArgument, "remote addr is not found in context")
		}
		addr := s.ServerTransport().RemoteAddr().String()
		md, ok := metadata.FromContext(ctx)
		if ok && len(md["redirect"]) != 0 {
			return ctx, grpc.Errorf(codes.ResourceExhausted, "more than one redirect to leader from: %s", md["redirect"])
		}
		if !ok {
			md = metadata.New(map[string]string{})
		}
		md["redirect"] = append(md["redirect"], addr)
		return metadata.NewContext(ctx, md), nil
	}
	mods := []func(context.Context)(context.Context, error){redirectChecker}
	if ctxMod != nil {
		mods = append(mods, ctxMod)
	}
	`)
	g.gen.P("return &" + serviceTypeName(s) + `{
		local: local,
		connSelector: connSelector,
		ctxMods: mods,
	}`)
	g.gen.P("}")
}
예제 #3
0
// genServerStreamingMethod emits the authenticated-wrapper implementation of
// the server-streaming method m of service s.
//
// The emitted body depends on the method's TLSAuthorization option:
//   - option missing: the generated method panics;
//   - Insecure set: the call goes straight to p.local (setting Roles together
//     with Insecure makes the generator itself panic);
//   - otherwise: p.authorize is called with the configured roles before
//     delegating to p.local.
func (g *authenticatedWrapperGen) genServerStreamingMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) {
	g.gen.P(sigPrefix(s, m) + "r *" + getInputTypeName(m) + ", stream " + s.GetName() + "_" + m.GetName() + "Server) error {")

	// emitBody writes the method body followed by its closing brace.
	emitBody := func(body string) {
		g.gen.P(body)
		g.gen.P("}")
	}

	ext, err := proto.GetExtension(m.Options, plugin.E_TlsAuthorization)
	if err != nil {
		emitBody(`
	panic("no authorization information in protobuf")`)
		return
	}

	auth := ext.(*plugin.TLSAuthorization)
	insecure := auth.Insecure != nil && *auth.Insecure

	if !insecure {
		emitBody(`
	if err := p.authorize(stream.Context(),` + genRoles(auth) + `); err != nil {
		return err
	}
	return p.local.` + m.GetName() + `(r, stream)`)
		return
	}

	if len(auth.Roles) != 0 {
		panic("Roles and Insecure cannot both be specified")
	}
	emitBody(`
	return p.local.` + m.GetName() + `(r, stream)`)
}
예제 #4
0
// genSimpleMethod emits the raft proxy implementation of the unary method m of
// service s.
//
// Flow of the generated method:
//   - this node is the leader (raftselector.ErrIsLeader): serve via p.local
//     with the original ctx;
//   - otherwise: run p.runCtxMods and forward the request to the leader.
//
// If the forwarded call fails with an error whose text contains
// "is closing", "the connection is unavailable" or "connection error", the
// method polls for a new leader connection once and retries. It falls back to
// p.local if this node has become the leader in the meantime.
func (g *raftProxyGen) genSimpleMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) {
	method := m.GetName()
	newClient := "New" + s.GetName() + "Client"

	g.gen.P(sigPrefix(s, m) + "ctx context.Context, r *" + getInputTypeName(m) + ") (*" + getOutputTypeName(m) + ", error) {")

	// Leader lookup, local fast path, and the modified context for forwarding.
	g.gen.P(`
	conn, err := p.connSelector.LeaderConn(ctx)
	if err != nil {
		if err == raftselector.ErrIsLeader {
			return p.local.` + method + `(ctx, r)
		}
		return nil, err
	}
	modCtx, err := p.runCtxMods(ctx)
	if err != nil {
		return nil, err
	}`)

	// Forwarded call with a single retry on connection-looking failures.
	g.gen.P(`
	resp, err := ` + newClient + `(conn).` + method + `(modCtx, r)
	if err != nil {
		if !strings.Contains(err.Error(), "is closing") && !strings.Contains(err.Error(), "the connection is unavailable") && !strings.Contains(err.Error(), "connection error") {
			return resp, err
		}
		conn, err := p.pollNewLeaderConn(ctx)
		if err != nil {
			if err == raftselector.ErrIsLeader {
				return p.local.` + method + `(ctx, r)
			}
			return nil, err
		}
		return ` + newClient + `(conn).` + method + `(modCtx, r)
	}`)

	g.gen.P("return resp, err")
	g.gen.P("}")
}
예제 #5
0
// genAuthenticatedConstructor emits NewAuthenticatedWrapper<Service>Server. The
// generated function returns the authenticated wrapper type for service s,
// populated with the given local server and authorize callback.
func (g *authenticatedWrapperGen) genAuthenticatedConstructor(s *descriptor.ServiceDescriptorProto) {
	svc := s.GetName()
	signature := "func NewAuthenticatedWrapper" + svc + "Server(local " + svc + "Server, authorize func(context.Context, []string) error)" + svc + "Server {"
	g.gen.P(signature)

	literal := "return &" + serviceTypeName(s) + `{
		local: local,
		authorize: authorize,
	}`
	g.gen.P(literal)
	g.gen.P("}")
}
예제 #6
0
// genClientStreamingMethod emits the raft proxy implementation of the
// client-streaming method m of service s.
//
// When this node is the leader, the generated method:
//   - applies p.localCtxMods;
//   - hands p.local a <streamType>Wrapper carrying the modified ctx.
//
// Otherwise it:
//   - applies p.remoteCtxMods;
//   - opens the same stream on the leader;
//   - copies every client message to it until io.EOF;
//   - relays the leader's reply back with SendAndClose.
func (g *raftProxyGen) genClientStreamingMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) {
	streamType := s.GetName() + "_" + m.GetName() + "Server"

	// Emit the wrapper type before the method; presumably genStreamWrapper
	// defines <streamType>Wrapper, which the local path below instantiates.
	g.genStreamWrapper(streamType)

	header := sigPrefix(s, m) + "stream " + streamType + `) error {
	ctx := stream.Context()
	conn, err := p.connSelector.LeaderConn(ctx)
	if err != nil {
		if err == raftselector.ErrIsLeader {
			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
			if err != nil {
				return err
			}
			streamWrapper := ` + streamType + `Wrapper{
				` + streamType + `: stream,
				ctx: ctx,
			}
			return p.local.` + m.GetName() + `(streamWrapper)
		}
		return err
	}
	ctx, err = p.runCtxMods(ctx, p.remoteCtxMods)
	if err != nil {
		return err
	}`
	openLeaderStream := "clientStream, err := New" + s.GetName() + "Client(conn)." + m.GetName() + "(ctx)"
	checkOpen := `
	if err != nil {
			return err
	}`
	relay := `
	for {
		msg, err := stream.Recv()
		if err == io.EOF {
			break
		}
		if err != nil {
			return err
		}
		if err := clientStream.Send(msg); err != nil {
			return err
		}
	}

	reply, err := clientStream.CloseAndRecv()
	if err != nil {
		return err
	}

	return stream.SendAndClose(reply)`

	for _, part := range []string{header, openLeaderStream, checkOpen, relay, "}"} {
		g.gen.P(part)
	}
}
예제 #7
0
// genClientServerStreamingMethod emits the raft proxy implementation of the
// bidirectional-streaming method m of service s.
//
// When this node is the raft leader, the generated method serves the stream
// locally. Otherwise it runs the context modifiers, opens the same stream on
// the leader, and pumps messages in both directions:
//   - a goroutine forwards EVERY client message to the leader. When the client
//     closes its side (io.EOF), it half-closes the leader stream with
//     CloseSend, so a leader that waits for the client's EOF can finish.
//     Its result goes to the buffered errc, so that send never blocks;
//   - the handler forwards every leader message to the client until the leader
//     ends the stream.
//
// Once the leader has finished, the call is over. The handler reports a
// forwarding error if one is already available, but does not wait for the
// client to close its side.
func (g *raftProxyGen) genClientServerStreamingMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) {
	g.gen.P(sigPrefix(s, m) + "stream " + s.GetName() + "_" + m.GetName() + "Server) error {")
	g.gen.P(`
	ctx := stream.Context()
	conn, err := p.connSelector.LeaderConn(ctx)
	if err != nil {
		if err == raftselector.ErrIsLeader {
			return p.local.` + m.GetName() + `(stream)
		}
		return err
	}
	ctx, err = p.runCtxMods(ctx)
	if err != nil {
		return err
	}`)
	g.gen.P("clientStream, err := New" + s.GetName() + "Client(conn)." + m.GetName() + "(ctx)")
	g.gen.P(`
	if err != nil {
			return err
	}`)
	g.gen.P(`errc := make(chan error, 1)
	go func() {
		for {
			msg, err := stream.Recv()
			if err == io.EOF {
				// Client is done sending; propagate the half-close to the leader.
				errc <- clientStream.CloseSend()
				return
			}
			if err != nil {
				errc <- err
				return
			}
			if err := clientStream.Send(msg); err != nil {
				errc <- err
				return
			}
		}
	}()`)
	g.gen.P(`
	for {
		msg, err := clientStream.Recv()
		if err == io.EOF {
			break
		}
		if err != nil {
			return err
		}
		if err := stream.Send(msg); err != nil {
			return err
		}
	}
	// The leader has ended the call. Report a forwarding failure if one is
	// already known, but do not block waiting for the client to close its
	// side; errc is buffered, so the goroutine never blocks on its send.
	select {
	case err := <-errc:
		return err
	default:
		return nil
	}`)
	g.gen.P("}")
}
예제 #8
0
// genServerStreamingMethod emits the raft proxy implementation of the
// server-streaming method m of service s.
//
// When this node is the leader, the generated method:
//   - applies p.localCtxMods;
//   - calls p.local with the request and a <streamType>Wrapper carrying the
//     modified ctx.
//
// Otherwise it:
//   - applies p.remoteCtxMods;
//   - opens the stream on the leader with the request;
//   - copies every leader message to the client until io.EOF.
func (g *raftProxyGen) genServerStreamingMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) {
	streamType := s.GetName() + "_" + m.GetName() + "Server"

	// Emit the wrapper type first; presumably genStreamWrapper defines
	// <streamType>Wrapper, which the local path below instantiates.
	g.genStreamWrapper(streamType)

	header := sigPrefix(s, m) + "r *" + getInputTypeName(m) + ", stream " + streamType + `) error {
	ctx := stream.Context()
	conn, err := p.connSelector.LeaderConn(ctx)
	if err != nil {
		if err == raftselector.ErrIsLeader {
			ctx, err = p.runCtxMods(ctx, p.localCtxMods)
			if err != nil {
				return err
			}
			streamWrapper := ` + streamType + `Wrapper{
				` + streamType + `: stream,
				ctx: ctx,
			}
			return p.local.` + m.GetName() + `(r, streamWrapper)
		}
		return err
	}
	ctx, err = p.runCtxMods(ctx, p.remoteCtxMods)
	if err != nil {
		return err
	}`
	openLeaderStream := "clientStream, err := New" + s.GetName() + "Client(conn)." + m.GetName() + "(ctx, r)"
	checkOpen := `
	if err != nil {
			return err
	}`
	relay := `
	for {
		msg, err := clientStream.Recv()
		if err == io.EOF {
			break
		}
		if err != nil {
			return err
		}
		if err := stream.Send(msg); err != nil {
			return err
		}
	}
	return nil`

	for _, part := range []string{header, openLeaderStream, checkOpen, relay, "}"} {
		g.gen.P(part)
	}
}
예제 #9
0
// genSimpleMethod emits the raft proxy implementation of the unary method m of
// service s.
//
// The generated method serves the request through p.local when this node is
// the raft leader. Otherwise it applies p.runCtxMods and forwards the request
// to the leader through a new <Service>Client.
func (g *raftProxyGen) genSimpleMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) {
	signature := sigPrefix(s, m) + "ctx context.Context, r *" + getInputTypeName(m) + ") (*" + getOutputTypeName(m) + ", error) {"
	leaderCheck := `
	conn, err := p.connSelector.LeaderConn(ctx)
	if err != nil {
		if err == raftselector.ErrIsLeader {
			return p.local.` + m.GetName() + `(ctx, r)
		}
		return nil, err
	}
	ctx, err = p.runCtxMods(ctx)
	if err != nil {
		return nil, err
	}`
	forward := "return New" + s.GetName() + "Client(conn)." + m.GetName() + "(ctx, r)"

	for _, line := range []string{signature, leaderCheck, forward, "}"} {
		g.gen.P(line)
	}
}
예제 #10
0
// generateService generates all the code for the named service.
// generateService generates all the code for the named service:
//   - the client interface, struct and New<Service>Client constructor;
//   - one client method per RPC;
//   - the server interface and the Register<Service>Server helper;
//   - one server handler per RPC;
//   - the grpc.ServiceDesc that binds handlers to method/stream names.
//
// index is the position of service within file. It builds the
// source-code-info path used to print each method's comments.
func (g *grpc) generateService(file *generator.FileDescriptor, service *pb.ServiceDescriptorProto, index int) {
	path := fmt.Sprintf("6,%d", index) // 6 means service.

	origServName := service.GetName()
	fullServName := origServName
	if pkg := file.GetPackage(); pkg != "" {
		fullServName = pkg + "." + fullServName
	}
	servName := generator.CamelCase(origServName)

	g.P()
	g.P("// Client API for ", servName, " service")
	g.P()

	// Client interface.
	g.P("type ", servName, "Client interface {")
	for i, method := range service.Method {
		g.gen.PrintComments(fmt.Sprintf("%s,2,%d", path, i)) // 2 means method in a service.
		g.P(g.generateClientSignature(servName, method))
	}
	g.P("}")
	g.P()

	// Client structure.
	g.P("type ", unexport(servName), "Client struct {")
	g.P("cc *", grpcPkg, ".ClientConn")
	g.P("}")
	g.P()

	// NewClient factory.
	g.P("func New", servName, "Client (cc *", grpcPkg, ".ClientConn) ", servName, "Client {")
	g.P("return &", unexport(servName), "Client{cc}")
	g.P("}")
	g.P()

	// Unary and streaming methods are numbered separately. Both sequences
	// follow service.Method order, which is the same order the ServiceDesc
	// loops below use to emit Methods and Streams, so the indexes line up.
	var methodIndex, streamIndex int
	serviceDescVar := "_" + servName + "_serviceDesc"
	// Client method implementations.
	for _, method := range service.Method {
		var descExpr string
		if !method.GetServerStreaming() && !method.GetClientStreaming() {
			// Unary RPC method
			descExpr = fmt.Sprintf("&%s.Methods[%d]", serviceDescVar, methodIndex)
			methodIndex++
		} else {
			// Streaming RPC method
			descExpr = fmt.Sprintf("&%s.Streams[%d]", serviceDescVar, streamIndex)
			streamIndex++
		}
		g.generateClientMethod(servName, fullServName, serviceDescVar, method, descExpr)
	}

	g.P("// Server API for ", servName, " service")
	g.P()

	// Server interface.
	serverType := servName + "Server"
	g.P("type ", serverType, " interface {")
	for i, method := range service.Method {
		g.gen.PrintComments(fmt.Sprintf("%s,2,%d", path, i)) // 2 means method in a service.
		g.P(g.generateServerSignature(servName, method))
	}
	g.P("}")
	g.P()

	// Server registration.
	g.P("func Register", servName, "Server(s *", grpcPkg, ".Server, srv ", serverType, ") {")
	g.P("s.RegisterService(&", serviceDescVar, `, srv)`)
	g.P("}")
	g.P()

	// Server handler implementations.
	var handlerNames []string
	for _, method := range service.Method {
		hname := g.generateServerMethod(servName, fullServName, method)
		handlerNames = append(handlerNames, hname)
	}

	// Service descriptor.
	g.P("var ", serviceDescVar, " = ", grpcPkg, ".ServiceDesc {")
	g.P("ServiceName: ", strconv.Quote(fullServName), ",")
	g.P("HandlerType: (*", serverType, ")(nil),")
	g.P("Methods: []", grpcPkg, ".MethodDesc{")
	for i, method := range service.Method {
		if method.GetServerStreaming() || method.GetClientStreaming() {
			continue
		}
		g.P("{")
		g.P("MethodName: ", strconv.Quote(method.GetName()), ",")
		g.P("Handler: ", handlerNames[i], ",")
		g.P("},")
	}
	g.P("},")
	g.P("Streams: []", grpcPkg, ".StreamDesc{")
	for i, method := range service.Method {
		if !method.GetServerStreaming() && !method.GetClientStreaming() {
			continue
		}
		g.P("{")
		g.P("StreamName: ", strconv.Quote(method.GetName()), ",")
		g.P("Handler: ", handlerNames[i], ",")
		if method.GetServerStreaming() {
			g.P("ServerStreams: true,")
		}
		if method.GetClientStreaming() {
			g.P("ClientStreams: true,")
		}
		g.P("},")
	}
	g.P("},")
	// Quote the file name like ServiceName above. A path containing a
	// backslash or double quote would otherwise produce an invalid (or
	// silently altered) Go string literal.
	g.P("Metadata: ", strconv.Quote(file.GetName()), ",")
	g.P("}")
	g.P()
}
예제 #11
0
// serviceTypeName returns the name of the unexported raft proxy type generated
// for service s: "raftProxy" + <service name> + "Server".
func serviceTypeName(s *descriptor.ServiceDescriptorProto) string {
	const prefix, suffix = "raftProxy", "Server"
	return prefix + s.GetName() + suffix
}
예제 #12
0
// serviceTypeName returns the name of the unexported authenticated wrapper
// type generated for service s: "authenticatedWrapper" + <service name> +
// "Server".
func serviceTypeName(s *descriptor.ServiceDescriptorProto) string {
	const prefix, suffix = "authenticatedWrapper", "Server"
	return prefix + s.GetName() + suffix
}
예제 #13
0
// genAuthenticatedStruct emits the type definition of the authenticated
// wrapper for service s. The generated struct has two fields:
//   - local: the wrapped <Service>Server;
//   - authorize: a func(context.Context, []string) error callback.
func (g *authenticatedWrapperGen) genAuthenticatedStruct(s *descriptor.ServiceDescriptorProto) {
	lines := []string{
		"type " + serviceTypeName(s) + " struct {",
		"	local " + s.GetName() + "Server",
		"	authorize func(context.Context, []string) error",
		"}",
	}
	for _, line := range lines {
		g.gen.P(line)
	}
}