示例#1
0
// filterAPIs returns one API entry for every method in methods that carries an
// E_Http option, in the order the methods appear.
//
// svcIndex is the service's position in its file. Combined with each method's
// position in methods, it forms the source-code-info path stored in
// descIndexPath ("6,<svcIndex>,2,<method>").
//
// The index field is the method's position among methods of the same kind:
// unary methods are counted separately from streaming ones. Every method is
// counted, including those without an HTTP rule. That position matches the
// method's slot in the Methods or Streams slice of the service descriptor.
//
// NOTE(review): descName is built from the raw service name, while the grpc
// generator names its descriptor variable after generator.CamelCase(name).
// Confirm the two agree for the service names in use.
func filterAPIs(service *pb.ServiceDescriptorProto, methods []*pb.MethodDescriptorProto, svcIndex int) []*API {
	servicePath := fmt.Sprintf("6,%d", svcIndex) // 6 means service.
	descVar := "_" + service.GetName() + "_serviceDesc"

	result := make([]*API, 0, len(methods))
	unaryCount, streamCount := 0, 0

	for pos, m := range methods {
		isStream := m.GetClientStreaming() || m.GetServerStreaming()

		// Claim this method's slot in the matching descriptor slice before
		// deciding whether it is exported over HTTP, so numbering stays aligned.
		var slot int
		if isStream {
			slot = streamCount
			streamCount++
		} else {
			slot = unaryCount
			unaryCount++
		}

		// An absent extension comes back as an error with a nil value; either
		// way the assertion below yields nil and the method is skipped.
		ext, _ := proto.GetExtension(m.Options, E_Http)
		rule, _ := ext.(*HttpRule)
		if rule == nil {
			continue
		}

		result = append(result, &API{
			service:       service,
			method:        m,
			desc:          rule,
			descIndexPath: fmt.Sprintf("%s,2,%d", servicePath, pos), // 2 means method in a service.
			descName:      descVar,
			stream:        isStream,
			index:         slot,
		})
	}

	return result
}
示例#2
0
// generateService generates all the code for the named service.
//
// index is the service's position in the file. It is used to build the
// source-code-info paths ("6,<index>,2,<method>") whose comments are copied
// onto the generated interface methods.
//
// Output, in order:
//   - the <Service>Client interface;
//   - its unexported implementation struct and the New<Service>Client
//     constructor;
//   - one client method per RPC;
//   - the <Service>Server interface;
//   - Register<Service>Server, which registers the descriptor after passing it
//     through the gogogrpc ApplyServerOptions helper;
//   - one server handler per RPC;
//   - the _<Service>_serviceDesc ServiceDesc value.
func (g *grpc) generateService(file *generator.FileDescriptor, service *pb.ServiceDescriptorProto, index int) {
	path := fmt.Sprintf("6,%d", index) // 6 means service.

	// The fully-qualified name omits the "pkg." prefix when the file has no package.
	origServName := service.GetName()
	fullServName := origServName
	if pkg := file.GetPackage(); pkg != "" {
		fullServName = pkg + "." + fullServName
	}
	servName := generator.CamelCase(origServName)

	g.P()
	g.P("// Client API for ", servName, " service")
	g.P()

	// Client interface.
	g.P("type ", servName, "Client interface {")
	for i, method := range service.Method {
		g.gen.PrintComments(fmt.Sprintf("%s,2,%d", path, i)) // 2 means method in a service.
		g.P(g.generateClientSignature(servName, method))
	}
	g.P("}")
	g.P()

	// Client structure.
	g.P("type ", unexport(servName), "Client struct {")
	g.P("cc *", grpcPkg, ".ClientConn")
	g.P("}")
	g.P()

	// NewClient factory.
	g.P("func New", servName, "Client (cc *", grpcPkg, ".ClientConn) ", servName, "Client {")
	g.P("return &", unexport(servName), "Client{cc}")
	g.P("}")
	g.P()

	// Unary and streaming RPCs are numbered separately. Each counter matches the
	// method's position in the Methods / Streams slice emitted at the end of
	// this function.
	var methodIndex, streamIndex int
	serviceDescVar := "_" + servName + "_serviceDesc"
	// Client method implementations.
	for _, method := range service.Method {
		var descExpr string
		if !method.GetServerStreaming() && !method.GetClientStreaming() {
			// Unary RPC method
			descExpr = fmt.Sprintf("&%s.Methods[%d]", serviceDescVar, methodIndex)
			methodIndex++
		} else {
			// Streaming RPC method
			descExpr = fmt.Sprintf("&%s.Streams[%d]", serviceDescVar, streamIndex)
			streamIndex++
		}
		g.generateClientMethod(servName, fullServName, serviceDescVar, method, descExpr)
	}

	g.P("// Server API for ", servName, " service")
	g.P()

	// Server interface.
	serverType := servName + "Server"
	g.P("type ", serverType, " interface {")
	for i, method := range service.Method {
		g.gen.PrintComments(fmt.Sprintf("%s,2,%d", path, i)) // 2 means method in a service.
		g.P(g.generateServerSignature(servName, method))
	}
	g.P("}")
	g.P()

	// Server registration.
	g.P("func Register", servName, "Server(s *", grpcPkg, ".Server, srv ", serverType, ", options ...", gogogrpcPkg, ".ServerOption) {")
	g.P("s.RegisterService(", gogogrpcPkg, ".ApplyServerOptions(&", serviceDescVar, ", srv, options))")
	g.P("}")
	g.P()

	// Server handler implementations.
	// handlerNames[i] is the generated handler for service.Method[i].
	var handlerNames []string
	for _, method := range service.Method {
		hname := g.generateServerMethod(servName, fullServName, method)
		handlerNames = append(handlerNames, hname)
	}

	// Service descriptor.
	g.P("var ", serviceDescVar, " = ", grpcPkg, ".ServiceDesc {")
	g.P("ServiceName: ", strconv.Quote(fullServName), ",")
	g.P("HandlerType: (*", serverType, ")(nil),")
	// Unary methods only, in declaration order.
	g.P("Methods: []", grpcPkg, ".MethodDesc{")
	for i, method := range service.Method {
		if method.GetServerStreaming() || method.GetClientStreaming() {
			continue
		}
		g.P("{")
		g.P("MethodName: ", strconv.Quote(method.GetName()), ",")
		g.P("Handler: ", handlerNames[i], ",")
		g.P("},")
	}
	g.P("},")
	// Streaming methods only, in declaration order, flagged by direction.
	g.P("Streams: []", grpcPkg, ".StreamDesc{")
	for i, method := range service.Method {
		if !method.GetServerStreaming() && !method.GetClientStreaming() {
			continue
		}
		g.P("{")
		g.P("StreamName: ", strconv.Quote(method.GetName()), ",")
		g.P("Handler: ", handlerNames[i], ",")
		if method.GetServerStreaming() {
			g.P("ServerStreams: true,")
		}
		if method.GetClientStreaming() {
			g.P("ClientStreams: true,")
		}
		g.P("},")
	}
	g.P("},")
	g.P("Metadata: ", file.VarName(), ",")
	g.P("}")
	g.P()
}
示例#3
0
// swaggerPathVarRe matches one "{...}" template variable in an HTTP path
// pattern. The negated class stops at the first "}", so each variable in a path
// such as "/v1/{a.b}/x.json/{c.d}" is matched on its own. A greedy ".+" would
// match from the first "{" to the last "}" and also rewrite the literal "." in
// "x.json".
var swaggerPathVarRe = regexp.MustCompile(`\{[^}]+\}`)

// generateServiceMethodSchema records a Swagger operation for one HTTP binding
// of a service method and appends it to g.operations.
//
// rule supplies the HTTP verb and path pattern. Only DELETE, GET, POST, PATCH
// and PUT are supported; any other pattern type panics. The rest of the
// operation is built as follows:
//   - A "?query" suffix on the pattern is split off and turned into query
//     parameters.
//   - Dots inside path variables are replaced by underscores in the recorded
//     pattern.
//   - For verbs other than GET and DELETE, the input message is added as the
//     request body.
//   - The operation is tagged with the service name plus rule.Tags and
//     described by comment.
//   - When the method declares a scope, an "oauth" security requirement is
//     added.
//   - A server-streaming method's 200 response is an array of the output
//     message.
//
// Each of rule's alternative bindings is processed recursively in the same way.
func (g *jsonschema) generateServiceMethodSchema(file *generator.FileDescriptor, srv *pb.ServiceDescriptorProto, meth *pb.MethodDescriptorProto, rule *limbo.HttpRule, comment string) {
	var (
		method  string
		pattern string
	)

	switch p := rule.GetPattern().(type) {
	case *limbo.HttpRule_Delete:
		method = "DELETE"
		pattern = p.Delete
	case *limbo.HttpRule_Get:
		method = "GET"
		pattern = p.Get
	case *limbo.HttpRule_Post:
		method = "POST"
		pattern = p.Post
	case *limbo.HttpRule_Patch:
		method = "PATCH"
		pattern = p.Patch
	case *limbo.HttpRule_Put:
		method = "PUT"
		pattern = p.Put
	default:
		panic("unknown pattern type")
	}

	query := ""
	if idx := strings.IndexByte(pattern, '?'); idx >= 0 {
		query = pattern[idx+1:]
		pattern = pattern[:idx]
	}

	// Rewrite "." to "_" inside each path variable only; literal path text is
	// left untouched.
	path := swaggerPathVarRe.ReplaceAllStringFunc(pattern, func(v string) string {
		return strings.Replace(v, ".", "_", -1)
	})

	input := strings.TrimPrefix(meth.GetInputType(), ".")
	output := strings.TrimPrefix(meth.GetOutputType(), ".")

	var outputSchema interface{} = map[string]string{"$ref": output}
	if meth.GetServerStreaming() {
		outputSchema = map[string]interface{}{
			"type":  "array",
			"items": outputSchema,
		}
	}

	var tags = []string{srv.GetName()}
	tags = append(tags, rule.Tags...)

	op := map[string]interface{}{
		"tags":        tags,
		"description": comment,
		"responses": map[string]interface{}{
			"200": map[string]interface{}{
				"description": "",
				"schema":      outputSchema,
			},
		},
	}

	var parameters []map[string]interface{}

	if params := g.collectPathParameters(pattern); len(params) > 0 {
		parameters = append(parameters, params...)
	}

	if params := g.collectQueryParameters(query); len(params) > 0 {
		parameters = append(parameters, params...)
	}

	// Only verbs that carry a request body get the input message as "body".
	// HEAD and OPTIONS cannot be produced by the switch above; they are listed
	// for completeness.
	if method != "HEAD" && method != "GET" && method != "OPTIONS" && method != "DELETE" {
		parameters = append(parameters, map[string]interface{}{
			"name": "parameters",
			"in":   "body",
			"schema": map[string]interface{}{
				"$ref": input,
			},
		})
	}

	if len(parameters) > 0 {
		op["parameters"] = parameters
	}

	if scope, ok := limbo.GetScope(meth); ok {
		op["security"] = []map[string][]string{
			{"oauth": {scope}},
		}
	}

	decl := &operationDecl{
		Pattern:      path,
		Method:       method,
		Dependencies: uniqStrings([]string{input, output}),
		Swagger:      op,
	}
	g.operations = append(g.operations, decl)

	for _, a := range rule.GetAlternatives() {
		g.generateServiceMethodSchema(file, srv, meth, a, comment)
	}
}
示例#4
0
// generateService generates all the code for the named service.
//
// It does nothing when findMethods reports no method with an authn/authz rule.
// Otherwise it writes a package-level ServiceAuthDesc variable, registered
// through an init hook. The variable lists every unary method of the service
// under Methods and every streaming method under Streams. When generateDesc
// contributes any signatures, it also writes a <Service>ServerAuth interface
// containing them, sorted and de-duplicated. index is unused.
func (g *svcauth) generateService(file *generator.FileDescriptor, service *pb.ServiceDescriptorProto, index int) {
	methods := g.findMethods(file, service)
	if len(methods) == 0 {
		return
	}

	origServName := service.GetName()
	// Follow the gRPC naming used for ServiceDesc.ServiceName: a service in a
	// file without a package has no "pkg." prefix (not a leading ".").
	fullServName := origServName
	if pkg := file.GetPackage(); pkg != "" {
		fullServName = pkg + "." + origServName
	}
	servName := generator.CamelCase(origServName)
	authDescVarName := "_" + servName + "_authDesc"

	// Auth rules keyed by method descriptor. Methods without rules are absent,
	// so the lookups below yield nil for them.
	authByMethod := make(map[*pb.MethodDescriptorProto]*authMethod, len(methods))
	for _, m := range methods {
		authByMethod[m.method] = m
	}

	g.gen.AddInitf("%s.RegisterServiceAuthDesc(&%s)", g.runtimePkg.Use(), authDescVarName)

	// Signatures contributed by generateDesc for the ServerAuth interface.
	var interfaceMethods []string

	g.P(`var `, authDescVarName, ` = `, g.runtimePkg.Use(), `.ServiceAuthDesc{`)
	g.P(`ServiceName: `, strconv.Quote(fullServName), `,`)
	g.P(`HandlerType: ((*`, servName, `Server)(nil)),`)
	g.P(`AuthHandlerType: ((*`, servName, `ServerAuth)(nil)),`)
	g.P(`Methods: []`, g.runtimePkg.Use(), `.MethodAuthDesc{`)
	for _, method := range service.Method {
		if method.GetServerStreaming() || method.GetClientStreaming() {
			continue
		}
		g.P(`{`)
		g.P(`MethodName: `, strconv.Quote(method.GetName()), `,`)
		g.generateDesc(servName, method, authByMethod[method], &interfaceMethods)
		g.P("},")
	}
	g.P("},")
	g.P(`Streams: []`, g.runtimePkg.Use(), `.StreamAuthDesc{`)
	for _, method := range service.Method {
		if !method.GetServerStreaming() && !method.GetClientStreaming() {
			continue
		}
		g.P(`{`)
		g.P(`StreamName: `, strconv.Quote(method.GetName()), `,`)
		g.generateDesc(servName, method, authByMethod[method], &interfaceMethods)
		g.P("},")
	}
	g.P("},")
	g.P("}")
	g.P()

	if len(interfaceMethods) > 0 {
		// Several methods may need the same hook. After sorting, duplicates are
		// adjacent and are emitted once.
		sort.Strings(interfaceMethods)
		last := ""
		g.P(`type `, servName, `ServerAuth interface {`)
		for _, sig := range interfaceMethods {
			if sig != last {
				last = sig
				g.P(sig)
			}
		}
		g.P(`}`)
	}
}
示例#5
0
// findMethods collects the authentication/authorization rules of every method
// in service.
//
// A method's rules come from its E_Authn / E_Authz options. When the service
// declares E_DefaultAuthn / E_DefaultAuthz, a method without its own rule gets
// an empty one, so the default still applies to it. Each resulting rule is
// passed through Inherit on the service default and then SetDefaults. Methods
// that end up with neither rule are omitted.
//
// Name is the method's gRPC full method name: "/<package>.<Service>/<Method>",
// or "/<Service>/<Method>" when the file declares no package.
func (g *svcauth) findMethods(file *generator.FileDescriptor, service *pb.ServiceDescriptorProto) []*authMethod {
	// Service-wide defaults. A missing extension is reported as an error with a
	// nil value, which the type assertions turn into a nil rule.
	var (
		defaultAuthnInfo *AuthnRule
		defaultAuthzInfo *AuthzRule
	)
	if service.Options != nil {
		v, _ := proto.GetExtension(service.Options, E_DefaultAuthn)
		defaultAuthnInfo, _ = v.(*AuthnRule)
		v, _ = proto.GetExtension(service.Options, E_DefaultAuthz)
		defaultAuthzInfo, _ = v.(*AuthzRule)
	}

	// Without a package the gRPC service name carries no "pkg." prefix.
	fullServName := service.GetName()
	if pkg := file.GetPackage(); pkg != "" {
		fullServName = pkg + "." + fullServName
	}

	methods := make([]*authMethod, 0, len(service.Method))
	for _, method := range service.Method {
		var (
			authnInfo *AuthnRule
			authzInfo *AuthzRule
		)

		// Guarded like service.Options above. Without options there are no
		// per-method rules.
		if method.Options != nil {
			v, _ := proto.GetExtension(method.Options, E_Authn)
			authnInfo, _ = v.(*AuthnRule)
			v, _ = proto.GetExtension(method.Options, E_Authz)
			authzInfo, _ = v.(*AuthzRule)
		}

		// authn
		if authnInfo == nil && defaultAuthnInfo != nil {
			authnInfo = &AuthnRule{}
		}
		if authnInfo != nil {
			// NOTE(review): defaultAuthnInfo may be nil here; Inherit presumably
			// tolerates a nil receiver — confirm.
			authnInfo = defaultAuthnInfo.Inherit(authnInfo)
			authnInfo.SetDefaults()
		}

		// authz
		if authzInfo == nil && defaultAuthzInfo != nil {
			authzInfo = &AuthzRule{}
		}
		if authzInfo != nil {
			// NOTE(review): defaultAuthzInfo may be nil here; see above.
			authzInfo = defaultAuthzInfo.Inherit(authzInfo)
			authzInfo.SetDefaults()
		}

		if authnInfo == nil && authzInfo == nil {
			continue
		}

		methods = append(methods, &authMethod{
			file:    file,
			service: service,
			method:  method,
			Authn:   authnInfo,
			Authz:   authzInfo,
			Name:    fmt.Sprintf("/%s/%s", fullServName, method.GetName()),
		})
	}

	return methods
}
示例#6
0
// generateService generates all the code for the named service.
//
// For every method of service that carries an HTTP rule (see filterAPIs), it
// emits two things:
//   - An entry in a package-level GatewayDesc variable, registered through an
//     init hook, that maps the HTTP method and path pattern (without any query
//     part) to a generated handler.
//   - One handler function per route. The handler rejects other HTTP methods
//     and creates a server stream via the runtime package's NewServerStream.
//     The stream's callback copies mapped query parameters and path variables
//     into the input message. The handler then dispatches to the gRPC handler
//     at the method's position in srvDesc.Methods or srvDesc.Streams.
//
// Nothing is emitted when no method has an HTTP rule.
func (g *svchttp) generateService(file *generator.FileDescriptor, service *pb.ServiceDescriptorProto, index int) {
	apis := filterAPIs(service, service.Method, index)
	if len(apis) == 0 {
		return
	}

	origServName := service.GetName()
	// Follow the gRPC naming: without a package there is no "pkg." prefix.
	fullServName := origServName
	if pkg := file.GetPackage(); pkg != "" {
		fullServName = pkg + "." + origServName
	}
	servName := generator.CamelCase(origServName)
	gatewayVarName := "_" + servName + "_gatewayDesc"

	g.gen.AddInitf("%s.RegisterGatewayDesc(&%s)", g.runtimePkg.Use(), gatewayVarName)

	g.P(`var `, gatewayVarName, ` = `, g.runtimePkg.Use(), `.GatewayDesc{`)
	g.P(`ServiceName: `, strconv.Quote(fullServName), `,`)
	g.P(`HandlerType: ((*`, servName, `Server)(nil)),`)
	g.P(`Routes: []`, g.runtimePkg.Use(), `.RouteDesc{`)
	for _, api := range apis {
		httpMethod, pattern, ok := api.GetMethodAndPattern()
		if !ok {
			g.gen.Fail("xyz.featherhead.http requires a method: pattern")
		}

		// Routes match on the path alone; the query part only feeds input mapping.
		if idx := strings.IndexByte(pattern, '?'); idx >= 0 {
			pattern = pattern[:idx]
		}

		g.P(`{`)
		g.P(`Method: `, strconv.Quote(httpMethod), `,`)
		g.P(`Pattern: `, strconv.Quote(pattern), `,`)
		g.P(`Handler: `, g.generateServerCallName(servName, api.method), `,`)
		g.P("},")
	}
	g.P("},")
	g.P("}")
	g.P()

	// Server handler implementations.
	for _, api := range apis {
		info, method := api.desc, api.method

		inputTypeName := method.GetInputType()
		inputType, _ := g.gen.ObjectNamed(inputTypeName).(*generator.Descriptor)

		httpMethod, pattern, ok := api.GetMethodAndPattern()
		if !ok {
			g.gen.Fail("xyz.featherhead.http requires a method: pattern")
		}

		// Split off "?name={field.path}&..." mappings. Names are kept in
		// declaration order so the generated code is identical across runs
		// (ranging over a map is not). A repeated name keeps its last mapping.
		var queryNames []string
		queryFields := map[string]string{}
		if idx := strings.IndexByte(pattern, '?'); idx >= 0 {
			queryString := pattern[idx+1:]
			pattern = pattern[:idx]

			for _, pair := range strings.Split(queryString, "&") {
				// Checks also reject "" (from "?" or a stray "&"), which previously
				// panicked on pair[len(pair)-1].
				sep := strings.Index(pair, "={")
				if sep < 0 || !strings.HasSuffix(pair, "}") {
					g.gen.Fail("invalid query parameter")
					continue
				}
				name := pair[:sep]
				if _, seen := queryFields[name]; !seen {
					queryNames = append(queryNames, name)
				}
				queryFields[name] = pair[sep+2 : len(pair)-1]
			}
		}

		vars, err := router.ExtractVariables(pattern)
		if err != nil {
			g.gen.Error(err)
			return
		}

		var (
			httpResponseWriter = g.httpPkg.Use() + ".ResponseWriter"
			httpRequest        = g.httpPkg.Use() + ".Request"
			contextContext     = g.contextPkg.Use() + ".Context"
		)

		handlerMethod := g.generateServerCallName(servName, method)
		jujuErrors := g.jujuErrorsPkg.Use()
		g.P("func ", handlerMethod, "(srvDesc *", g.grpcPkg.Use(), ".ServiceDesc, srv interface{}, ctx ", contextContext, ", rw ", httpResponseWriter, ", req *", httpRequest, ") error {")
		g.P("if req.Method != ", strconv.Quote(httpMethod), "{")
		g.P(`  return `, jujuErrors, `.MethodNotAllowedf("expected `, httpMethod, ` request")`)
		g.P("}")
		g.P()

		if len(vars) > 0 {
			routerP := g.routerPkg.Use() + ".P"
			g.P(`params := `, routerP, `(ctx)`)
		}

		g.P(`stream, err := `, g.runtimePkg.Use(), `.NewServerStream(ctx, rw, req, `,
			method.GetServerStreaming(), `, `, method.GetClientStreaming(), `, `, int(info.PageSize), `, func(x interface{}) error {`)
		g.P(`input := x.(*`, g.typeName(inputTypeName), `)`)
		g.P(`_ = input`)
		g.P()

		for _, param := range queryNames {
			field := queryFields[param]
			g.P("// populate ?", param, "=", field)
			g.generateHttpMapping(inputType, field, "req.URL.Query().Get("+strconv.Quote(param)+")")
		}

		for _, v := range vars {
			g.P("// populate {", v.Name, "}")
			g.generateHttpMapping(inputType, v.Name, "params.Get("+strconv.Quote(v.Name)+")")
		}

		g.P(`return nil`)
		g.P(`})`)
		g.P()

		// api.index is the method's slot in Methods (unary) or Streams (streaming).
		if !api.stream {
			g.P(`desc := &srvDesc.Methods[`, api.index, `]`)
			g.P(`output, err := desc.Handler(srv, stream.Context(), stream.RecvMsg, nil)`)
			g.P(`if err == nil && output == nil {`)
			g.P(`err = `, g.grpcPkg.Use(), `.Errorf(`, g.grpcCodesPkg.Use(), `.Internal, "internal server error")`)
			g.P(`}`)
			g.P(`if err == nil {`)
			g.P(`err = stream.SendMsg(output)`)
			g.P(`}`)
		} else {
			g.P(`desc := &srvDesc.Streams[`, api.index, `]`)
			g.P(`err = desc.Handler(srv, stream)`)
		}
		g.P(`if err != nil {`)
		g.P(`stream.SetError(err)`)
		g.P(`}`)
		g.P()

		g.P(`return stream.CloseSend()`)
		g.P("}")
		g.P()
	}
}