func newCLITest() cliTest { // Reset the client context for each test. We don't reset the // pointer (because they are tied into the flags), but instead // overwrite the existing struct's values. context.InitDefaults() osStderr = os.Stdout s := &server.TestServer{} if err := s.Start(); err != nil { log.Fatalf("Could not start server: %v", err) } tempDir, err := ioutil.TempDir("", "cli-test") if err != nil { log.Fatal(err) } // Copy these assets to disk from embedded strings, so this test can // run from a standalone binary. // Disable embedded certs, or the security library will try to load // our real files as embedded assets. security.ResetReadFileFn() assets := []string{ security.CACertPath(security.EmbeddedCertsDir), security.ClientCertPath(security.EmbeddedCertsDir, security.RootUser), security.ClientKeyPath(security.EmbeddedCertsDir, security.RootUser), security.ClientCertPath(security.EmbeddedCertsDir, security.NodeUser), security.ClientKeyPath(security.EmbeddedCertsDir, security.NodeUser), } cleanups := []func(){} for _, a := range assets { _, cleanupFn := securitytest.RestrictedCopy(nil, a, tempDir, filepath.Base(a)) cleanups = append(cleanups, cleanupFn) } return cliTest{ TestServer: s, certsDir: tempDir, cleanupFunc: func() { for _, f := range cleanups { f() } }, } }
// PGUrl returns a postgres connection url which connects to this server with // the given user. Returns a connection string and a cleanup function which must // be called after any connection created using the string has been closed. // // In order to connect securely using postgres, this method will create // temporary on-disk copies of certain embedded security certificates. The // certificates will be created as temporary files in the provided directory, // and their filenames will have the provided prefix. The returned cleanup // function will delete these temporary files. func PGUrl(t util.Tester, ts *server.TestServer, user, tempDir, prefix string) (url.URL, func()) { host, port, err := net.SplitHostPort(ts.PGAddr()) if err != nil { t.Fatal(err) } caPath := filepath.Join(security.EmbeddedCertsDir, "ca.crt") certPath := security.ClientCertPath(security.EmbeddedCertsDir, user) keyPath := security.ClientKeyPath(security.EmbeddedCertsDir, user) // Copy these assets to disk from embedded strings, so this test can // run from a standalone binary. tempCAPath, tempCACleanup := securitytest.TempRestrictedCopy(t, caPath, tempDir, "TestLogic_ca") tempCertPath, tempCertCleanup := securitytest.TempRestrictedCopy(t, certPath, tempDir, "TestLogic_cert") tempKeyPath, tempKeyCleanup := securitytest.TempRestrictedCopy(t, keyPath, tempDir, "TestLogic_key") return url.URL{ Scheme: "postgres", User: url.User(user), Host: net.JoinHostPort(host, port), RawQuery: fmt.Sprintf("sslmode=verify-full&sslrootcert=%s&sslcert=%s&sslkey=%s", url.QueryEscape(tempCAPath), url.QueryEscape(tempCertPath), url.QueryEscape(tempKeyPath), ), }, func() { tempCACleanup() tempCertCleanup() tempKeyCleanup() } }
func makeSQLClient() (*sql.DB, string) { sqlURL := connURL if len(connURL) == 0 { options := url.Values{} if context.Insecure { options.Add("sslmode", "disable") } else { options.Add("sslmode", "verify-full") options.Add("sslcert", security.ClientCertPath(context.Certs, connUser)) options.Add("sslkey", security.ClientKeyPath(context.Certs, connUser)) options.Add("sslrootcert", security.CACertPath(context.Certs)) } pgURL := url.URL{ Scheme: "postgresql", User: url.User(connUser), Host: net.JoinHostPort(connHost, connPGPort), Path: connDBName, RawQuery: options.Encode(), } sqlURL = pgURL.String() } db, err := sql.Open("postgres", sqlURL) if err != nil { panicf("failed to initialize SQL client: %s", err) } return db, sqlURL }
// PGURL returns the URL for the postgres endpoint. func (ctx *Context) PGURL(user string) *url.URL { // Try to convert path to an absolute path. Failing to do so return path // unchanged. absPath := func(path string) string { r, err := filepath.Abs(path) if err != nil { return path } return r } options := url.Values{} if ctx.Insecure { options.Add("sslmode", "disable") } else { options.Add("sslmode", "verify-full") options.Add("sslcert", absPath(security.ClientCertPath(ctx.Certs, user))) options.Add("sslkey", absPath(security.ClientKeyPath(ctx.Certs, user))) options.Add("sslrootcert", absPath(security.CACertPath(ctx.Certs))) } return &url.URL{ Scheme: "postgresql", User: url.User(user), Host: ctx.Addr, RawQuery: options.Encode(), } }
// PGUrl returns a URL string for the given node postgres server. func (l *LocalCluster) PGUrl(i int) string { certUser := security.RootUser options := url.Values{} options.Add("sslmode", "verify-full") options.Add("sslcert", security.ClientCertPath(l.CertsDir, certUser)) options.Add("sslkey", security.ClientKeyPath(l.CertsDir, certUser)) options.Add("sslrootcert", security.CACertPath(l.CertsDir)) pgURL := url.URL{ Scheme: "postgres", User: url.User(certUser), Host: l.Nodes[i].PGAddr().String(), RawQuery: options.Encode(), } return pgURL.String() }
func makeSQLClient() (*sql.DB, string) { // Use the sql administrator by default (root user). sqlURL := connURL if len(connURL) == 0 { sslOptions := "" if context.Insecure { sslOptions = "sslmode=disable" } else { sslOptions = fmt.Sprintf("sslmode=verify-full&sslcert=%s&sslkey=%s&sslrootcert=%s", security.ClientCertPath(context.Certs, connUser), security.ClientKeyPath(context.Certs, connUser), security.CACertPath(context.Certs)) } sqlURL = fmt.Sprintf("postgresql://%s@%s:%s/%s?%s", connUser, connHost, connPGPort, connDBName, sslOptions) } db, err := sql.Open("postgres", sqlURL) if err != nil { panicf("failed to initialize SQL client: %s", err) } return db, sqlURL }
// PGUrl returns a postgres connection url which connects to this server with the given user, and a // cleanup function which must be called after all connections created using the connection url have // been closed. // // In order to connect securely using postgres, this method will create temporary on-disk copies of // certain embedded security certificates. The certificates will be created in a new temporary // directory. The returned cleanup function will delete this temporary directory. func PGUrl(t testing.TB, ts *server.TestServer, user, prefix string) (url.URL, func()) { host, port, err := net.SplitHostPort(ts.PGAddr()) if err != nil { t.Fatal(err) } tempDir, err := ioutil.TempDir("", prefix) if err != nil { t.Fatal(err) } caPath := security.CACertPath(security.EmbeddedCertsDir) certPath := security.ClientCertPath(security.EmbeddedCertsDir, user) keyPath := security.ClientKeyPath(security.EmbeddedCertsDir, user) // Copy these assets to disk from embedded strings, so this test can // run from a standalone binary. tempCAPath := securitytest.RestrictedCopy(t, caPath, tempDir, "ca") tempCertPath := securitytest.RestrictedCopy(t, certPath, tempDir, "cert") tempKeyPath := securitytest.RestrictedCopy(t, keyPath, tempDir, "key") options := url.Values{} options.Add("sslmode", "verify-full") options.Add("sslrootcert", tempCAPath) options.Add("sslcert", tempCertPath) options.Add("sslkey", tempKeyPath) return url.URL{ Scheme: "postgres", User: url.User(user), Host: net.JoinHostPort(host, port), RawQuery: options.Encode(), }, func() { if err := os.RemoveAll(tempDir); err != nil { // Not Fatal() because we might already be panicking. t.Error(err) } } }
func TestPGWire(t *testing.T) { defer leaktest.AfterTest(t) certUser := server.TestUser certPath := security.ClientCertPath(security.EmbeddedCertsDir, certUser) keyPath := security.ClientKeyPath(security.EmbeddedCertsDir, certUser) tempDir, err := ioutil.TempDir("", "TestPGWire") if err != nil { t.Fatal(err) } defer func() { if err := os.RemoveAll(tempDir); err != nil { // Not Fatal() because we might already be panicking. t.Error(err) } }() // Copy these assets to disk from embedded strings, so this test can // run from a standalone binary. tempCertPath := securitytest.RestrictedCopy(t, certPath, tempDir, "cert") tempKeyPath := securitytest.RestrictedCopy(t, keyPath, tempDir, "key") for _, insecure := range [...]bool{true, false} { ctx := server.NewTestContext() ctx.Insecure = insecure s := setupTestServerWithContext(t, ctx) host, port, err := net.SplitHostPort(s.PGAddr()) if err != nil { t.Fatal(err) } basePgUrl := url.URL{ Scheme: "postgres", Host: net.JoinHostPort(host, port), } if err := trivialQuery(basePgUrl); err != nil { if insecure { if err != pq.ErrSSLNotSupported { t.Error(err) } } else { if !testutils.IsError(err, "no client certificates in request") { t.Error(err) } } } { disablePgUrl := basePgUrl disablePgUrl.RawQuery = "sslmode=disable" err := trivialQuery(disablePgUrl) if insecure { if err != nil { t.Error(err) } } else { if !testutils.IsError(err, pgwire.ErrSSLRequired) { t.Error(err) } } } { requirePgUrlNoCert := basePgUrl requirePgUrlNoCert.RawQuery = "sslmode=require" err := trivialQuery(requirePgUrlNoCert) if insecure { if err != pq.ErrSSLNotSupported { t.Error(err) } } else { if !testutils.IsError(err, "no client certificates in request") { t.Error(err) } } } { for _, optUser := range []string{certUser, security.RootUser} { requirePgUrlWithCert := basePgUrl requirePgUrlWithCert.User = url.User(optUser) requirePgUrlWithCert.RawQuery = fmt.Sprintf("sslmode=require&sslcert=%s&sslkey=%s", url.QueryEscape(tempCertPath), url.QueryEscape(tempKeyPath), ) err := trivialQuery(requirePgUrlWithCert) if insecure { if err != pq.ErrSSLNotSupported { t.Error(err) } } else { if optUser == certUser { if err != nil { t.Error(err) } } else { if !testutils.IsError(err, `requested user is \w+, but certificate is for \w+`) { t.Error(err) } } } } } cleanupTestServer(s) } }