Example #1
0
// ConstantValuePtr generates an expression which is a pointer to a value of
// type $t.
//
// Dispatch is on compile.RootTypeSpec(t):
//   - bool, i8, i16, i32, i64, double and string are wrapped with the
//     matching helper from go.uber.org/thriftrw/ptr (ptr.Bool, ptr.Int8, ...).
//   - enums are wrapped with a generated `_<ThriftName>_ptr` helper, which is
//     declared through g.EnsureDeclared before use.
//   - every other type is rendered by ConstantValue with no pointer wrapping.
//
// If rendering the constant or declaring the enum helper fails, the error is
// returned together with an empty string.
func ConstantValuePtr(g Generator, c compile.ConstantValue, t compile.TypeSpec) (string, error) {
	// ptrHelper returns a reference to the named function in the thriftrw ptr
	// package. g.Import is only invoked on the branches that actually use it.
	ptrHelper := func(name string) string {
		return fmt.Sprintf("%v.%v", g.Import("go.uber.org/thriftrw/ptr"), name)
	}

	var ptrFunc string
	switch compile.RootTypeSpec(t).(type) {
	case *compile.BoolSpec:
		ptrFunc = ptrHelper("Bool")
	case *compile.I8Spec:
		ptrFunc = ptrHelper("Int8")
	case *compile.I16Spec:
		ptrFunc = ptrHelper("Int16")
	case *compile.I32Spec:
		ptrFunc = ptrHelper("Int32")
	case *compile.I64Spec:
		ptrFunc = ptrHelper("Int64")
	case *compile.DoubleSpec:
		ptrFunc = ptrHelper("Float64")
	case *compile.StringSpec:
		ptrFunc = ptrHelper("String")
	case *compile.EnumSpec:
		ptrFunc = fmt.Sprintf("_%v_ptr", t.ThriftName())
		err := g.EnsureDeclared(
			`func _<.ThriftName>_ptr(v <typeReference .>) *<typeReference .> {
				return &v
			}`, t)
		if err != nil {
			return "", err
		}
	default:
		return ConstantValue(g, c, t) // not a primitive
	}

	s, err := ConstantValue(g, c, t)
	if err != nil {
		// Don't hand back a malformed "ptrFunc(...)" expression on failure.
		return "", err
	}
	return fmt.Sprintf("%v(%v)", ptrFunc, s), nil
}
Example #2
0
// valueName derives a name fragment for spec. Map, list and set specs are
// expanded recursively into Map_<key>_<value>, List_<value> and Set_<value>;
// any other spec uses the goCase form of its Thrift name.
func valueName(spec compile.TypeSpec) string {
	if m, ok := spec.(*compile.MapSpec); ok {
		return "Map_" + valueName(m.KeySpec) + "_" + valueName(m.ValueSpec)
	}
	if l, ok := spec.(*compile.ListSpec); ok {
		return "List_" + valueName(l.ValueSpec)
	}
	if st, ok := spec.(*compile.SetSpec); ok {
		return "Set_" + valueName(st.ValueSpec)
	}
	return goCase(spec.ThriftName())
}
Example #3
0
// LookupTypeName returns the Go name under which t can be referenced from the
// package currently being generated (g.ImportPath). Types that live in a
// different package are qualified with the identifier returned by g.Import for
// that package.
//
// An error is returned if t has no Thrift file (native types), if the package
// for t's Thrift file cannot be determined, or if goName fails for t.
func (g *generator) LookupTypeName(t compile.TypeSpec) (string, error) {
	thriftFile := t.ThriftFile()
	if len(thriftFile) == 0 {
		return "", fmt.Errorf(
			"LookupTypeName called with native type (%T) %v", t, t)
	}

	pkgPath, err := g.thriftImporter.Package(thriftFile)
	if err != nil {
		return "", err
	}

	typeName, err := goName(t)
	if err != nil {
		return "", err
	}

	// Same package: the bare name is enough.
	if pkgPath == g.ImportPath {
		return typeName, nil
	}
	return g.Import(pkgPath) + "." + typeName, nil
}
Example #4
0
// TestGenerate runs Generate over a two-module fixture (module "foo", which
// includes module "bar"; each declares a Timestamp typedef) under a range of
// plugin configurations, and checks either the files written to a temporary
// output directory or the returned error.
func TestGenerate(t *testing.T) {
	var (
		ts compile.TypeSpec = &compile.TypedefSpec{
			Name:   "Timestamp",
			File:   testdata(t, "thrift/common/bar.thrift"),
			Target: &compile.I64Spec{},
		}
		ts2 compile.TypeSpec = &compile.TypedefSpec{
			Name:   "Timestamp",
			File:   testdata(t, "thrift/foo.thrift"),
			Target: ts,
		}
	)

	ts2, err := ts2.Link(compile.EmptyScope("bar"))
	require.NoError(t, err)

	ts, err = ts.Link(compile.EmptyScope("bar"))
	require.NoError(t, err)

	module := &compile.Module{
		Name:       "foo",
		ThriftPath: testdata(t, "thrift/foo.thrift"),
		Includes: map[string]*compile.IncludedModule{
			"bar": {
				Name: "bar",
				Module: &compile.Module{
					Name:       "bar",
					ThriftPath: testdata(t, "thrift/common/bar.thrift"),
					Types:      map[string]compile.TypeSpec{"Timestamp": ts},
				},
			},
		},
		Types: map[string]compile.TypeSpec{"Timestamp": ts2},
	}

	tests := []struct {
		desc      string
		noRecurse bool
		getPlugin func(*gomock.Controller) plugin.Handle

		wantFiles []string
		wantError string
	}{
		{
			desc:      "nil plugin; no recurse",
			noRecurse: true,
			wantFiles: []string{"foo/types.go"},
		},
		{
			desc: "nil plugin; recurse",
			wantFiles: []string{
				"foo/types.go",
				"common/bar/types.go",
			},
		},
		{
			desc: "no service generator",
			getPlugin: func(mockCtrl *gomock.Controller) plugin.Handle {
				handle := handletest.NewMockHandle(mockCtrl)
				handle.EXPECT().ServiceGenerator().Return(nil)
				return handle
			},
			wantFiles: []string{
				"foo/types.go",
				"common/bar/types.go",
			},
		},
		{
			desc: "empty plugin",
			getPlugin: func(mockCtrl *gomock.Controller) plugin.Handle {
				return plugin.EmptyHandle
			},
			wantFiles: []string{
				"foo/types.go",
				"common/bar/types.go",
			},
		},
		{
			desc: "ServiceGenerator plugin",
			getPlugin: func(mockCtrl *gomock.Controller) plugin.Handle {
				sgen := handletest.NewMockServiceGenerator(mockCtrl)
				sgen.EXPECT().Generate(gomock.Any()).
					Return(&api.GenerateServiceResponse{
						Files: map[string][]byte{
							"foo.txt":    []byte("hello world\n"),
							"bar/baz.go": []byte("package bar\n"),
						},
					}, nil)

				handle := handletest.NewMockHandle(mockCtrl)
				handle.EXPECT().ServiceGenerator().Return(sgen)
				return handle
			},
			wantFiles: []string{
				"foo/types.go",
				"common/bar/types.go",
				"foo.txt",
				"bar/baz.go",
			},
		},
		{
			desc: "ServiceGenerator plugin conflict",
			getPlugin: func(mockCtrl *gomock.Controller) plugin.Handle {
				sgen := handletest.NewMockServiceGenerator(mockCtrl)
				sgen.EXPECT().Generate(gomock.Any()).
					Return(&api.GenerateServiceResponse{
						Files: map[string][]byte{
							"common/bar/types.go": []byte("hulk smash"),
						},
					}, nil)

				handle := handletest.NewMockHandle(mockCtrl)
				handle.EXPECT().ServiceGenerator().Return(sgen)
				return handle
			},
			wantError: `file generation conflict: multiple sources are trying to write to "common/bar/types.go"`,
		},
		{
			desc: "ServiceGenerator plugin error",
			getPlugin: func(mockCtrl *gomock.Controller) plugin.Handle {
				sgen := handletest.NewMockServiceGenerator(mockCtrl)
				sgen.EXPECT().Generate(gomock.Any()).Return(nil, errors.New("great sadness"))

				handle := handletest.NewMockHandle(mockCtrl)
				handle.EXPECT().ServiceGenerator().Return(sgen)
				return handle
			},
			wantError: `great sadness`,
		},
	}

	for _, tt := range tests {
		// Each case runs as a subtest so its defers (mock verification,
		// temp-dir cleanup) fire per case and failures are reported by name.
		t.Run(tt.desc, func(t *testing.T) {
			mockCtrl := gomock.NewController(t)
			defer mockCtrl.Finish()

			outputDir, err := ioutil.TempDir(os.TempDir(), "test-generate-recurse")
			require.NoError(t, err)
			defer os.RemoveAll(outputDir)

			var p plugin.Handle
			if tt.getPlugin != nil {
				p = tt.getPlugin(mockCtrl)
			}

			err = Generate(module, &Options{
				OutputDir:     outputDir,
				PackagePrefix: "go.uber.org/thriftrw/gen/testdata",
				ThriftRoot:    testdata(t, "thrift"),
				Plugin:        p,
				NoRecurse:     tt.noRecurse,
			})
			if tt.wantError != "" {
				// Fail cleanly (rather than panic on err.Error()) if
				// Generate unexpectedly succeeds.
				require.Error(t, err, tt.desc)
				assert.Contains(t, err.Error(), tt.wantError, tt.desc)
				return
			}

			require.NoError(t, err, tt.desc)
			for _, f := range tt.wantFiles {
				_, statErr := os.Stat(filepath.Join(outputDir, f))
				assert.NoError(t, statErr, "%v: expected file %q to be generated", tt.desc, f)
			}
		})
	}
}