func filterAPIs(service *pb.ServiceDescriptorProto, methods []*pb.MethodDescriptorProto, svcIndex int) []*API { var apis = make([]*API, 0, len(methods)) path := fmt.Sprintf("6,%d", svcIndex) // 6 means service. var ( descName = "_" + service.GetName() + "_serviceDesc" methodIdx = 0 streamIdx = 0 ) for i, method := range methods { stream := method.GetClientStreaming() || method.GetServerStreaming() index := 0 if stream { index = streamIdx } else { index = methodIdx } v, _ := proto.GetExtension(method.Options, E_Http) info, _ := v.(*HttpRule) if info != nil { apis = append(apis, &API{ service: service, method: method, desc: info, descIndexPath: fmt.Sprintf("%s,2,%d", path, i), // 2 means method in a service. descName: descName, stream: stream, index: index, }) } if stream { streamIdx++ } else { methodIdx++ } } return apis }
// generateService generates all the code for the named service. func (g *grpc) generateService(file *generator.FileDescriptor, service *pb.ServiceDescriptorProto, index int) { path := fmt.Sprintf("6,%d", index) // 6 means service. origServName := service.GetName() fullServName := origServName if pkg := file.GetPackage(); pkg != "" { fullServName = pkg + "." + fullServName } servName := generator.CamelCase(origServName) g.P() g.P("// Client API for ", servName, " service") g.P() // Client interface. g.P("type ", servName, "Client interface {") for i, method := range service.Method { g.gen.PrintComments(fmt.Sprintf("%s,2,%d", path, i)) // 2 means method in a service. g.P(g.generateClientSignature(servName, method)) } g.P("}") g.P() // Client structure. g.P("type ", unexport(servName), "Client struct {") g.P("cc *", grpcPkg, ".ClientConn") g.P("}") g.P() // NewClient factory. g.P("func New", servName, "Client (cc *", grpcPkg, ".ClientConn) ", servName, "Client {") g.P("return &", unexport(servName), "Client{cc}") g.P("}") g.P() var methodIndex, streamIndex int serviceDescVar := "_" + servName + "_serviceDesc" // Client method implementations. for _, method := range service.Method { var descExpr string if !method.GetServerStreaming() && !method.GetClientStreaming() { // Unary RPC method descExpr = fmt.Sprintf("&%s.Methods[%d]", serviceDescVar, methodIndex) methodIndex++ } else { // Streaming RPC method descExpr = fmt.Sprintf("&%s.Streams[%d]", serviceDescVar, streamIndex) streamIndex++ } g.generateClientMethod(servName, fullServName, serviceDescVar, method, descExpr) } g.P("// Server API for ", servName, " service") g.P() // Server interface. serverType := servName + "Server" g.P("type ", serverType, " interface {") for i, method := range service.Method { g.gen.PrintComments(fmt.Sprintf("%s,2,%d", path, i)) // 2 means method in a service. g.P(g.generateServerSignature(servName, method)) } g.P("}") g.P() // Server registration. g.P("func Register", servName, "Server(s *", grpcPkg, ".Server, srv ", serverType, ", options ...", gogogrpcPkg, ".ServerOption) {") g.P("s.RegisterService(", gogogrpcPkg, ".ApplyServerOptions(&", serviceDescVar, ", srv, options))") g.P("}") g.P() // Server handler implementations. var handlerNames []string for _, method := range service.Method { hname := g.generateServerMethod(servName, fullServName, method) handlerNames = append(handlerNames, hname) } // Service descriptor. g.P("var ", serviceDescVar, " = ", grpcPkg, ".ServiceDesc {") g.P("ServiceName: ", strconv.Quote(fullServName), ",") g.P("HandlerType: (*", serverType, ")(nil),") g.P("Methods: []", grpcPkg, ".MethodDesc{") for i, method := range service.Method { if method.GetServerStreaming() || method.GetClientStreaming() { continue } g.P("{") g.P("MethodName: ", strconv.Quote(method.GetName()), ",") g.P("Handler: ", handlerNames[i], ",") g.P("},") } g.P("},") g.P("Streams: []", grpcPkg, ".StreamDesc{") for i, method := range service.Method { if !method.GetServerStreaming() && !method.GetClientStreaming() { continue } g.P("{") g.P("StreamName: ", strconv.Quote(method.GetName()), ",") g.P("Handler: ", handlerNames[i], ",") if method.GetServerStreaming() { g.P("ServerStreams: true,") } if method.GetClientStreaming() { g.P("ClientStreams: true,") } g.P("},") } g.P("},") g.P("Metadata: ", file.VarName(), ",") g.P("}") g.P() }
func (g *jsonschema) generateServiceMethodSchema(file *generator.FileDescriptor, srv *pb.ServiceDescriptorProto, meth *pb.MethodDescriptorProto, rule *limbo.HttpRule, comment string) { var ( method string pattern string ) switch p := rule.GetPattern().(type) { case *limbo.HttpRule_Delete: method = "DELETE" pattern = p.Delete case *limbo.HttpRule_Get: method = "GET" pattern = p.Get case *limbo.HttpRule_Post: method = "POST" pattern = p.Post case *limbo.HttpRule_Patch: method = "PATCH" pattern = p.Patch case *limbo.HttpRule_Put: method = "PUT" pattern = p.Put default: panic("unknown pattern type") } query := "" if idx := strings.IndexByte(pattern, '?'); idx >= 0 { query = pattern[idx+1:] pattern = pattern[:idx] } path := regexp.MustCompile("\\{.+\\}").ReplaceAllStringFunc(pattern, func(v string) string { return strings.Replace(v, ".", "_", -1) }) input := strings.TrimPrefix(meth.GetInputType(), ".") output := strings.TrimPrefix(meth.GetOutputType(), ".") var outputSchema interface{} = map[string]string{"$ref": output} if meth.GetServerStreaming() { outputSchema = map[string]interface{}{ "type": "array", "items": outputSchema, } } var tags = []string{srv.GetName()} tags = append(tags, rule.Tags...) op := map[string]interface{}{ "tags": tags, "description": comment, "responses": map[string]interface{}{ "200": map[string]interface{}{ "description": "", "schema": outputSchema, }, }, } var parameters []map[string]interface{} if params := g.collectPathParameters(pattern); len(params) > 0 { parameters = append(parameters, params...) } if params := g.collectQueryParameters(query); len(params) > 0 { parameters = append(parameters, params...) } if method != "HEAD" && method != "GET" && method != "OPTIONS" && method != "DELETE" { parameters = append(parameters, map[string]interface{}{ "name": "parameters", "in": "body", "schema": map[string]interface{}{ "$ref": input, }, }) } if len(parameters) > 0 { op["parameters"] = parameters } if scope, ok := limbo.GetScope(meth); ok { op["security"] = []map[string][]string{ {"oauth": {scope}}, } } decl := &operationDecl{ Pattern: path, Method: method, Dependencies: uniqStrings([]string{input, output}), Swagger: op, } g.operations = append(g.operations, decl) // x, _ := json.MarshalIndent(op, " ", " ") // fmt.Fprintf(os.Stderr, "%s %s:\n %s\n", method, path, x) for _, a := range rule.GetAlternatives() { g.generateServiceMethodSchema(file, srv, meth, a, comment) } }
// generateService generates all the code for the named service. func (g *svcauth) generateService(file *generator.FileDescriptor, service *pb.ServiceDescriptorProto, index int) { methods := g.findMethods(file, service) if len(methods) == 0 { return } origServName := service.GetName() fullServName := file.GetPackage() + "." + origServName servName := generator.CamelCase(origServName) authDescVarName := "_" + servName + "_authDesc" methodsByName := make(map[*pb.MethodDescriptorProto]*authMethod) for _, m := range methods { methodsByName[m.method] = m } g.gen.AddInitf("%s.RegisterServiceAuthDesc(&%s)", g.runtimePkg.Use(), authDescVarName) var interfaceMethods []string g.P(`var `, authDescVarName, ` = `, g.runtimePkg.Use(), `.ServiceAuthDesc{`) g.P(`ServiceName: `, strconv.Quote(fullServName), `,`) g.P(`HandlerType: ((*`, servName, `Server)(nil)),`) g.P(`AuthHandlerType: ((*`, servName, `ServerAuth)(nil)),`) g.P(`Methods: []`, g.runtimePkg.Use(), `.MethodAuthDesc{`) for _, method := range service.Method { if method.GetServerStreaming() || method.GetClientStreaming() { continue } g.P(`{`) g.P(`MethodName: `, strconv.Quote(method.GetName()), `,`) g.generateDesc(servName, method, methodsByName[method], &interfaceMethods) g.P("},") } g.P("},") g.P(`Streams: []`, g.runtimePkg.Use(), `.StreamAuthDesc{`) for _, method := range service.Method { if !method.GetServerStreaming() && !method.GetClientStreaming() { continue } g.P(`{`) g.P(`StreamName: `, strconv.Quote(method.GetName()), `,`) g.generateDesc(servName, method, methodsByName[method], &interfaceMethods) g.P("},") } g.P("},") g.P("}") g.P() if len(interfaceMethods) > 0 { sort.Strings(interfaceMethods) last := "" g.P(`type `, servName, `ServerAuth interface {`) for _, sig := range interfaceMethods { if sig != last { last = sig g.P(sig) } } g.P(`}`) } }
func (g *svcauth) findMethods(file *generator.FileDescriptor, service *pb.ServiceDescriptorProto) []*authMethod { methods := make([]*authMethod, 0, len(service.Method)) var ( defaultAuthnInfo *AuthnRule defaultAuthzInfo *AuthzRule ) if service.Options != nil { v, _ := proto.GetExtension(service.Options, E_DefaultAuthn) defaultAuthnInfo, _ = v.(*AuthnRule) } if service.Options != nil { v, _ := proto.GetExtension(service.Options, E_DefaultAuthz) defaultAuthzInfo, _ = v.(*AuthzRule) } for _, method := range service.Method { var ( authnInfo *AuthnRule authzInfo *AuthzRule ) { // authn v, _ := proto.GetExtension(method.Options, E_Authn) authnInfo, _ = v.(*AuthnRule) if authnInfo == nil && defaultAuthnInfo != nil { authnInfo = &AuthnRule{} } if authnInfo != nil { authnInfo = defaultAuthnInfo.Inherit(authnInfo) authnInfo.SetDefaults() } } { // authz v, _ := proto.GetExtension(method.Options, E_Authz) authzInfo, _ = v.(*AuthzRule) if authzInfo == nil && defaultAuthzInfo != nil { authzInfo = &AuthzRule{} } if authzInfo != nil { authzInfo = defaultAuthzInfo.Inherit(authzInfo) authzInfo.SetDefaults() } } if authnInfo == nil && authzInfo == nil { continue } methods = append(methods, &authMethod{ file: file, service: service, method: method, Authn: authnInfo, Authz: authzInfo, Name: fmt.Sprintf("/%s.%s/%s", file.GetPackage(), service.GetName(), method.GetName()), }) } return methods }
// generateService generates all the code for the named service. func (g *svchttp) generateService(file *generator.FileDescriptor, service *pb.ServiceDescriptorProto, index int) { apis := filterAPIs(service, service.Method, index) if len(apis) == 0 { return } origServName := service.GetName() fullServName := file.GetPackage() + "." + origServName servName := generator.CamelCase(origServName) gatewayVarName := "_" + servName + "_gatewayDesc" g.gen.AddInitf("%s.RegisterGatewayDesc(&%s)", g.runtimePkg.Use(), gatewayVarName) g.P(`var `, gatewayVarName, ` = `, g.runtimePkg.Use(), `.GatewayDesc{`) g.P(`ServiceName: `, strconv.Quote(fullServName), `,`) g.P(`HandlerType: ((*`, servName, `Server)(nil)),`) g.P(`Routes: []`, g.runtimePkg.Use(), `.RouteDesc{`) for _, api := range apis { _, method := api.desc, api.method httpMethod, pattern, ok := api.GetMethodAndPattern() if !ok { g.gen.Fail("xyz.featherhead.http requires a method: pattern") } if idx := strings.IndexRune(pattern, '?'); idx >= 0 { pattern = pattern[:idx] } g.P(`{`) g.P(`Method: `, strconv.Quote(httpMethod), `,`) g.P(`Pattern: `, strconv.Quote(pattern), `,`) g.P(`Handler: `, g.generateServerCallName(servName, method), `,`) g.P("},") } g.P("},") g.P("}") g.P() // Server handler implementations. for _, api := range apis { info, method := api.desc, api.method inputTypeName := method.GetInputType() inputType, _ := g.gen.ObjectNamed(inputTypeName).(*generator.Descriptor) httpMethod, pattern, ok := api.GetMethodAndPattern() queryParams := map[string]string{} if !ok { g.gen.Fail("xyz.featherhead.http requires a method: pattern") } if idx := strings.IndexRune(pattern, '?'); idx >= 0 { queryString := pattern[idx+1:] pattern = pattern[:idx] for _, pair := range strings.SplitN(queryString, "&", -1) { idx := strings.Index(pair, "={") if pair[len(pair)-1] != '}' || idx < 0 { g.gen.Fail("invalid query paramter") } queryParams[pair[:idx]] = pair[idx+2 : len(pair)-1] } } vars, err := router.ExtractVariables(pattern) if err != nil { g.gen.Error(err) return } var ( httpResponseWriter = g.httpPkg.Use() + ".ResponseWriter" httpRequest = g.httpPkg.Use() + ".Request" contextContext = g.contextPkg.Use() + ".Context" ) handlerMethod := g.generateServerCallName(servName, method) jujuErrors := g.jujuErrorsPkg.Use() g.P("func ", handlerMethod, "(srvDesc *", g.grpcPkg.Use(), ".ServiceDesc, srv interface{}, ctx ", contextContext, ", rw ", httpResponseWriter, ", req *", httpRequest, ") error {") g.P("if req.Method != ", strconv.Quote(httpMethod), "{") g.P(` return `, jujuErrors, `.MethodNotAllowedf("expected `, httpMethod, ` request")`) g.P("}") g.P() if len(vars) > 0 { routerP := g.routerPkg.Use() + ".P" g.P(`params := `, routerP, `(ctx)`) } g.P(`stream, err := `, g.runtimePkg.Use(), `.NewServerStream(ctx, rw, req, `, method.GetServerStreaming(), `, `, method.GetClientStreaming(), `, `, int(info.PageSize), `, func(x interface{}) error {`) g.P(`input := x.(*`, g.typeName(inputTypeName), `)`) g.P(`_ = input`) g.P() for param, value := range queryParams { g.P("// populate ?", param, "=", value) g.generateHttpMapping(inputType, value, "req.URL.Query().Get("+strconv.Quote(param)+")") } for _, v := range vars { g.P("// populate {", v.Name, "}") g.generateHttpMapping(inputType, v.Name, "params.Get("+strconv.Quote(v.Name)+")") } g.P(`return nil`) g.P(`})`) g.P() if !api.stream { g.P(`desc := &srvDesc.Methods[`, api.index, `]`) g.P(`output, err := desc.Handler(srv, stream.Context(), stream.RecvMsg, nil)`) g.P(`if err == nil && output == nil {`) g.P(`err = `, g.grpcPkg.Use(), `.Errorf(`, g.grpcCodesPkg.Use(), `.Internal, "internal server error")`) g.P(`}`) g.P(`if err == nil {`) g.P(`err = stream.SendMsg(output)`) g.P(`}`) } else { g.P(`desc := &srvDesc.Streams[`, api.index, `]`) g.P(`err = desc.Handler(srv, stream)`) } g.P(`if err != nil {`) g.P(`stream.SetError(err)`) g.P(`}`) g.P() g.P(`return stream.CloseSend()`) g.P("}") g.P() } }