예제 #1
0
파일: main.go 프로젝트: jackc/tern
// Connect opens a single pgx connection using the receiver's settings.
//
// When SSHConnConfig.Host is non-empty, the connection is dialed through the
// client returned by NewSSHClient. When SslMode is one of the recognised
// libpq sslmode values, PGHOST and PGSSLMODE are exported to the process
// environment and pgx.ParseEnvLibpq is used to derive the TLS settings, which
// are then copied onto ConnConfig. Any other SslMode leaves TLS untouched.
//
// Note: the os.Setenv calls change process-wide state and are not undone.
func (c *Config) Connect() (*pgx.Conn, error) {
	if c.SSHConnConfig.Host != "" {
		sshClient, sshErr := NewSSHClient(&c.SSHConnConfig)
		if sshErr != nil {
			return nil, sshErr
		}
		c.ConnConfig.Dial = sshClient.Dial
	}

	switch c.SslMode {
	case "disable", "allow", "prefer", "require", "verify-ca", "verify-full":
		// pgx.ParseEnvLibpq only reads the environment, so the sslmode from the
		// config file / CLI has to be exported first. PGHOST is exported as well,
		// presumably so the TLS config gets the right ServerName — confirm.
		exports := [...]struct{ name, value string }{
			{"PGHOST", c.ConnConfig.Host},
			{"PGSSLMODE", c.SslMode},
		}
		for _, e := range exports {
			if err := os.Setenv(e.name, e.value); err != nil {
				return nil, err
			}
		}

		envConfig, err := pgx.ParseEnvLibpq()
		if err != nil {
			return nil, err
		}
		c.ConnConfig.TLSConfig = envConfig.TLSConfig
		c.ConnConfig.UseFallbackTLS = envConfig.UseFallbackTLS
		c.ConnConfig.FallbackTLSConfig = envConfig.FallbackTLSConfig
	}

	return pgx.Connect(c.ConnConfig)
}
예제 #2
0
// Connect connects to the database using the libpq-style PG* environment
// variables (via pgx.ParseEnvLibpq). It opens a pgx connection pool of at most
// 25 connections, wraps the pool in a database/sql handle, and opens that
// handle with gorm. After connecting, it auto-migrates the secrets.Secret and
// secrets.Key models, creating their tables if missing.
//
// If a step fails before gorm has been opened successfully, the resources
// created so far are released, so a failed Connect does not leak pooled
// connections. p.conn is only assigned once gorm.Open succeeds.
func (p *DB) Connect() (err error) {
	cfg, err := pgx.ParseEnvLibpq()
	if err != nil {
		return err
	}

	pool, err := pgx.NewConnPool(pgx.ConnPoolConfig{
		ConnConfig:     cfg,
		MaxConnections: 25,
	})
	if err != nil {
		return err
	}

	sqlDB, err := pgx_stdlib.OpenFromConnPool(pool)
	if err != nil {
		pool.Close()
		return err
	}

	conn, err := gorm.Open("postgres", sqlDB)
	if err != nil {
		// gorm was handed an existing *sql.DB, so it presumably does not own it;
		// release it and the pool here. Cleanup errors are secondary to err.
		_ = sqlDB.Close()
		pool.Close()
		return err
	}
	p.conn = conn

	return p.conn.AutoMigrate(&secrets.Secret{}, &secrets.Key{}).Error
}
예제 #3
0
파일: main.go 프로젝트: jackc/go_db_bench
// extractConfig builds the benchmark's pool configuration from the libpq PG*
// environment variables. An unset host, user or database falls back to
// "localhost", the USER environment variable and "go_db_bench" respectively.
// TLS is always switched off and the pool is capped at 10 connections.
//
// If pgx.ParseEnvLibpq fails, its error is returned together with whatever
// ConnConfig it produced.
func extractConfig() (config pgx.ConnPoolConfig, err error) {
	if config.ConnConfig, err = pgx.ParseEnvLibpq(); err != nil {
		return config, err
	}

	fallbacks := []struct {
		target *string
		value  string
	}{
		{&config.Host, "localhost"},
		{&config.User, os.Getenv("USER")},
		{&config.Database, "go_db_bench"},
	}
	for _, fb := range fallbacks {
		if *fb.target == "" {
			*fb.target = fb.value
		}
	}

	// Always connect without TLS, regardless of PGSSLMODE.
	config.TLSConfig, config.UseFallbackTLS = nil, false
	config.MaxConnections = 10

	return config, nil
}
예제 #4
0
// createConnPool opens a pgx connection pool configured from the libpq PG*
// environment variables. An unset host, user or database falls back to
// "localhost", the USER environment variable and "pgxdata" respectively.
// TLS is disabled and the pool allows at most 10 connections.
func createConnPool() (*pgx.ConnPool, error) {
	connConfig, err := pgx.ParseEnvLibpq()
	if err != nil {
		return nil, err
	}

	// setDefault assigns value to *field only when the environment left it empty.
	setDefault := func(field *string, value string) {
		if *field == "" {
			*field = value
		}
	}
	setDefault(&connConfig.Host, "localhost")
	setDefault(&connConfig.User, os.Getenv("USER"))
	setDefault(&connConfig.Database, "pgxdata")

	// Plain-text connections only, whatever PGSSLMODE says.
	connConfig.TLSConfig = nil
	connConfig.UseFallbackTLS = false

	return pgx.NewConnPool(pgx.ConnPoolConfig{
		ConnConfig:     connConfig,
		MaxConnections: 10,
	})
}
예제 #5
0
파일: db.go 프로젝트: jaittola/loca
// getConfig returns a connection pool configuration (at most 10 connections)
// built from the libpq PG* environment variables. If pgx.ParseEnvLibpq reports
// an error, it terminates the process via log.Fatalf.
func getConfig() pgx.ConnPoolConfig {
	connConfig, err := pgx.ParseEnvLibpq()
	if err != nil {
		log.Fatalf("Postgresql connection information missing from the environment: %v", err)
	}

	return pgx.ConnPoolConfig{
		ConnConfig:     connConfig,
		MaxConnections: 10,
	}
}
예제 #6
0
파일: pgx.go 프로젝트: hlandau/degoutils
// NewPgxPool creates a pgx connection pool sized by maxConnectionsFlag.
//
// The url argument selects how connection settings are obtained:
//   - "" reads the libpq PG* environment variables (pgx.ParseEnvLibpq);
//   - a "postgres://" or "postgresql://" URI is parsed with pgx.ParseURI;
//   - anything else is treated as a key=value DSN and parsed with pgx.ParseDSN.
func NewPgxPool(url string) (*pgx.ConnPool, error) {
	dbcfg := pgx.ConnPoolConfig{
		MaxConnections: maxConnectionsFlag.Value(),
	}

	var err error
	switch {
	case url == "":
		dbcfg.ConnConfig, err = pgx.ParseEnvLibpq()
	case strings.HasPrefix(url, "postgres://"), strings.HasPrefix(url, "postgresql://"):
		// libpq accepts both URI scheme designators; "postgres://" must not be
		// routed to the key=value DSN parser.
		dbcfg.ConnConfig, err = pgx.ParseURI(url)
	default:
		dbcfg.ConnConfig, err = pgx.ParseDSN(url)
	}
	if err != nil {
		return nil, err
	}

	return pgx.NewConnPool(dbcfg)
}
예제 #7
0
파일: main.go 프로젝트: jackc/tern
// LoadConfig builds the tern configuration. It starts from VersionTable
// "schema_version" and the libpq PG* environment variables, then applies, in
// order, the config file (if any) and the command-line arguments. Presumably
// later sources override earlier ones; confirm in appendConfigFromFile /
// appendConfigFromCLIArgs.
//
// If no config path was given and ./tern.conf exists, that file is used and
// recorded in cliOptions.configPath. Finally, the SSH user defaults to the
// current OS user and the SSH port defaults to "ssh".
func LoadConfig() (*Config, error) {
	config := &Config{VersionTable: "schema_version"}

	connConfig, err := pgx.ParseEnvLibpq()
	if err != nil {
		return nil, err
	}
	config.ConnConfig = connConfig

	// Set default config path only if it exists
	if cliOptions.configPath == "" {
		if _, statErr := os.Stat("./tern.conf"); statErr == nil {
			cliOptions.configPath = "./tern.conf"
		}
	}

	if cliOptions.configPath != "" {
		if err := appendConfigFromFile(config, cliOptions.configPath); err != nil {
			return nil, err
		}
	}

	appendConfigFromCLIArgs(config)

	if config.SSHConnConfig.User == "" {
		// Named currentUser so the os/user package is not shadowed.
		currentUser, userErr := user.Current()
		if userErr != nil {
			return nil, userErr
		}
		config.SSHConnConfig.User = currentUser.Username
	}

	if config.SSHConnConfig.Port == "" {
		config.SSHConnConfig.Port = "ssh"
	}

	return config, nil
}
예제 #8
0
파일: generate.go 프로젝트: jackc/pgxdata
// generateCmd implements the "generate" subcommand, which takes no positional
// arguments. It reads connection settings from the libpq PG* environment
// variables, then decodes config.toml into the same Config. It connects with
// pgx and inspects the configured tables. Finally, it writes the support files
// (pgxdata_attribute.go, pgxdata_db.go) and one pgxdata_<table>.go file per
// configured table into the current directory.
//
// Any error is printed to stderr and the process exits with status 1.
func generateCmd(cmd *cobra.Command, args []string) {
	// exitOnError reports err on stderr and terminates the process.
	exitOnError := func(err error) {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}

	if len(args) != 0 {
		fmt.Fprintln(os.Stderr, "generate does not take any arguments")
		os.Exit(1)
	}

	var c Config
	var err error
	c.Database, err = pgx.ParseEnvLibpq()
	if err != nil {
		exitOnError(err)
	}

	// Decoded after the environment so that, presumably, values present in
	// config.toml take precedence over the PG* variables — confirm.
	if _, err = toml.DecodeFile("config.toml", &c); err != nil {
		exitOnError(err)
	}

	conn, err := pgx.Connect(c.Database)
	if err != nil {
		exitOnError(err)
	}

	if err = inspectDatabase(conn, c.Tables); err != nil {
		exitOnError(err)
	}

	templates := loadTemplates()

	supportData := initData{
		PkgName: c.Package,
		Version: VERSION,
		BoxTypes: []boxType{
			{Name: "Bool", ValueType: "bool", FormatCode: "pgx.BinaryFormatCode"},
			{Name: "Int16", ValueType: "int16", FormatCode: "pgx.BinaryFormatCode"},
			{Name: "Int32", ValueType: "int32", FormatCode: "pgx.BinaryFormatCode"},
			{Name: "Int64", ValueType: "int64", FormatCode: "pgx.BinaryFormatCode"},
			{Name: "String", ValueType: "string", FormatCode: "pgx.TextFormatCode"},
			{Name: "Time", ValueType: "time.Time", FormatCode: "pgx.BinaryFormatCode"},
			{Name: "IPNet", ValueType: "net.IPNet", FormatCode: "pgx.BinaryFormatCode"},
		},
		IntBoxTypes: []intBoxType{
			{Name: "Int16", BitSize: 16},
			{Name: "Int32", BitSize: 32},
			{Name: "Int64", BitSize: 64},
		},
	}

	supportFiles := []struct {
		path string
		tmpl *template.Template
	}{
		{"pgxdata_attribute.go", templates.Lookup("attribute")},
		{"pgxdata_db.go", templates.Lookup("db")},
	}
	for _, f := range supportFiles {
		if err := writeSupportFile(f.path, f.tmpl, supportData); err != nil {
			exitOnError(err)
		}
	}

	for _, t := range c.Tables {
		file, err := os.Create("pgxdata_" + goCaseToFileCase(t.StructName) + ".go")
		if err != nil {
			exitOnError(err)
		}

		if err := writeTableCrud(file, templates, c.Package, t); err != nil {
			exitOnError(err)
		}

		// Close can surface write errors that were not reported earlier, so
		// a failure here means the generated file may be incomplete.
		if err := file.Close(); err != nil {
			exitOnError(err)
		}
	}
}
예제 #9
0
파일: conn_test.go 프로젝트: yunhor/pgx
// TestParseEnvLibpq checks that pgx.ParseEnvLibpq builds the expected
// ConnConfig from the libpq PG* environment variables. This includes the TLS
// settings derived from each PGSSLMODE value.
//
// Every variable in pgEnvvars is cleared before each case. When the test
// finishes, each one is restored to its original state, including "unset".
// This keeps cases from leaking environment into each other or into other
// tests.
func TestParseEnvLibpq(t *testing.T) {
	// PGSSLMODE must be listed: several cases set it, and otherwise it would be
	// neither cleared between cases nor restored afterwards.
	pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE"}

	type savedVar struct {
		value string
		isSet bool
	}
	savedEnv := make(map[string]savedVar, len(pgEnvvars))
	for _, n := range pgEnvvars {
		v, ok := os.LookupEnv(n)
		savedEnv[n] = savedVar{value: v, isSet: ok}
	}
	defer func() {
		for k, sv := range savedEnv {
			var err error
			if sv.isSet {
				err = os.Setenv(k, sv.value)
			} else {
				// Unset rather than setting "", since an empty value and an
				// absent variable need not be treated the same.
				err = os.Unsetenv(k)
			}
			if err != nil {
				// Errorf (not Fatalf) so the remaining variables still get restored.
				t.Errorf("Unable to restore environment: %v", err)
			}
		}
	}()

	tests := []struct {
		name    string
		envvars map[string]string
		config  pgx.ConnConfig
	}{
		{
			name:    "No environment",
			envvars: map[string]string{},
			config: pgx.ConnConfig{
				TLSConfig:         &tls.Config{InsecureSkipVerify: true},
				UseFallbackTLS:    true,
				FallbackTLSConfig: nil,
				RuntimeParams:     map[string]string{},
			},
		},
		{
			name: "Normal PG vars",
			envvars: map[string]string{
				"PGHOST":     "123.123.123.123",
				"PGPORT":     "7777",
				"PGDATABASE": "foo",
				"PGUSER":     "******",
				"PGPASSWORD": "******",
			},
			config: pgx.ConnConfig{
				Host:              "123.123.123.123",
				Port:              7777,
				Database:          "foo",
				User:              "******",
				Password:          "******",
				TLSConfig:         &tls.Config{InsecureSkipVerify: true},
				UseFallbackTLS:    true,
				FallbackTLSConfig: nil,
				RuntimeParams:     map[string]string{},
			},
		},
		{
			name: "application_name",
			envvars: map[string]string{
				"PGAPPNAME": "pgxtest",
			},
			config: pgx.ConnConfig{
				TLSConfig:         &tls.Config{InsecureSkipVerify: true},
				UseFallbackTLS:    true,
				FallbackTLSConfig: nil,
				RuntimeParams:     map[string]string{"application_name": "pgxtest"},
			},
		},
		{
			name: "sslmode=disable",
			envvars: map[string]string{
				"PGSSLMODE": "disable",
			},
			config: pgx.ConnConfig{
				TLSConfig:      nil,
				UseFallbackTLS: false,
				RuntimeParams:  map[string]string{},
			},
		},
		{
			name: "sslmode=allow",
			envvars: map[string]string{
				"PGSSLMODE": "allow",
			},
			config: pgx.ConnConfig{
				TLSConfig:         nil,
				UseFallbackTLS:    true,
				FallbackTLSConfig: &tls.Config{InsecureSkipVerify: true},
				RuntimeParams:     map[string]string{},
			},
		},
		{
			name: "sslmode=prefer",
			envvars: map[string]string{
				"PGSSLMODE": "prefer",
			},
			config: pgx.ConnConfig{
				TLSConfig:         &tls.Config{InsecureSkipVerify: true},
				UseFallbackTLS:    true,
				FallbackTLSConfig: nil,
				RuntimeParams:     map[string]string{},
			},
		},
		{
			name: "sslmode=require",
			envvars: map[string]string{
				"PGSSLMODE": "require",
			},
			config: pgx.ConnConfig{
				TLSConfig:      &tls.Config{},
				UseFallbackTLS: false,
				RuntimeParams:  map[string]string{},
			},
		},
		{
			name: "sslmode=verify-ca",
			envvars: map[string]string{
				"PGSSLMODE": "verify-ca",
			},
			config: pgx.ConnConfig{
				TLSConfig:      &tls.Config{},
				UseFallbackTLS: false,
				RuntimeParams:  map[string]string{},
			},
		},
		{
			name: "sslmode=verify-full",
			envvars: map[string]string{
				"PGSSLMODE": "verify-full",
			},
			config: pgx.ConnConfig{
				TLSConfig:      &tls.Config{},
				UseFallbackTLS: false,
				RuntimeParams:  map[string]string{},
			},
		},
		{
			name: "sslmode=verify-full with host",
			envvars: map[string]string{
				"PGHOST":    "pgx.example",
				"PGSSLMODE": "verify-full",
			},
			config: pgx.ConnConfig{
				Host: "pgx.example",
				TLSConfig: &tls.Config{
					ServerName: "pgx.example",
				},
				UseFallbackTLS: false,
				RuntimeParams:  map[string]string{},
			},
		},
	}

	for _, tt := range tests {
		for _, n := range pgEnvvars {
			err := os.Unsetenv(n)
			if err != nil {
				t.Fatalf("%s: Unable to clear environment: %v", tt.name, err)
			}
		}

		for k, v := range tt.envvars {
			err := os.Setenv(k, v)
			if err != nil {
				t.Fatalf("%s: Unable to set environment: %v", tt.name, err)
			}
		}

		config, err := pgx.ParseEnvLibpq()
		if err != nil {
			t.Errorf("%s: Unexpected error from pgx.ParseEnvLibpq() => %v", tt.name, err)
			continue
		}

		if config.Host != tt.config.Host {
			t.Errorf("%s: expected Host to be %v got %v", tt.name, tt.config.Host, config.Host)
		}
		if config.Port != tt.config.Port {
			t.Errorf("%s: expected Port to be %v got %v", tt.name, tt.config.Port, config.Port)
		}
		if config.Database != tt.config.Database {
			t.Errorf("%s: expected Database to be %v got %v", tt.name, tt.config.Database, config.Database)
		}
		if config.User != tt.config.User {
			t.Errorf("%s: expected User to be %v got %v", tt.name, tt.config.User, config.User)
		}
		if config.Password != tt.config.Password {
			t.Errorf("%s: expected Password to be %v got %v", tt.name, tt.config.Password, config.Password)
		}

		if !reflect.DeepEqual(config.RuntimeParams, tt.config.RuntimeParams) {
			t.Errorf("%s: expected RuntimeParams to be %#v got %#v", tt.name, tt.config.RuntimeParams, config.RuntimeParams)
		}

		// TLS configs are compared field-by-field on the settings these cases
		// care about, rather than with DeepEqual on the whole tls.Config.
		tlsTests := []struct {
			name     string
			expected *tls.Config
			actual   *tls.Config
		}{
			{
				name:     "TLSConfig",
				expected: tt.config.TLSConfig,
				actual:   config.TLSConfig,
			},
			{
				name:     "FallbackTLSConfig",
				expected: tt.config.FallbackTLSConfig,
				actual:   config.FallbackTLSConfig,
			},
		}
		for _, tlsTest := range tlsTests {
			name := tlsTest.name
			expected := tlsTest.expected
			actual := tlsTest.actual

			if expected == nil && actual != nil {
				t.Errorf("%s / %s: expected nil, but it was set", tt.name, name)
			} else if expected != nil && actual == nil {
				t.Errorf("%s / %s: expected to be set, but got nil", tt.name, name)
			} else if expected != nil && actual != nil {
				if actual.InsecureSkipVerify != expected.InsecureSkipVerify {
					t.Errorf("%s / %s: expected InsecureSkipVerify to be %v got %v", tt.name, name, expected.InsecureSkipVerify, actual.InsecureSkipVerify)
				}

				if actual.ServerName != expected.ServerName {
					t.Errorf("%s / %s: expected ServerName to be %v got %v", tt.name, name, expected.ServerName, actual.ServerName)
				}
			}
		}

		if config.UseFallbackTLS != tt.config.UseFallbackTLS {
			t.Errorf("%s: expected UseFallbackTLS to be %v got %v", tt.name, tt.config.UseFallbackTLS, config.UseFallbackTLS)
		}
	}
}