예제 #1
0
// PGUrl returns a postgres connection url which connects to this server with
// the given user. Returns a connection string and a cleanup function which must
// be called after any connection created using the string has been closed.
//
// In order to connect securely using postgres, this method will create
// temporary on-disk copies of certain embedded security certificates. The
// certificates will be created as temporary files in the provided directory,
// and their filenames will have the provided prefix (the names passed to
// RestrictedCopy are prefix+"_ca", prefix+"_cert" and prefix+"_key"). The
// returned cleanup function will delete these temporary files.
func PGUrl(t util.Tester, ts *server.TestServer, user, tempDir, prefix string) (url.URL, func()) {
	host, port, err := net.SplitHostPort(ts.PGAddr())
	if err != nil {
		t.Fatal(err)
	}

	caPath := security.CACertPath(security.EmbeddedCertsDir)
	certPath := security.ClientCertPath(security.EmbeddedCertsDir, user)
	keyPath := security.ClientKeyPath(security.EmbeddedCertsDir, user)

	// Copy these assets to disk from embedded strings, so this test can
	// run from a standalone binary. The temp-file names are derived from the
	// caller-supplied prefix, as promised by the doc comment.
	tempCAPath, tempCACleanup := securitytest.RestrictedCopy(t, caPath, tempDir, prefix+"_ca")
	tempCertPath, tempCertCleanup := securitytest.RestrictedCopy(t, certPath, tempDir, prefix+"_cert")
	tempKeyPath, tempKeyCleanup := securitytest.RestrictedCopy(t, keyPath, tempDir, prefix+"_key")

	return url.URL{
			Scheme: "postgres",
			User:   url.User(user),
			Host:   net.JoinHostPort(host, port),
			RawQuery: fmt.Sprintf("sslmode=verify-full&sslrootcert=%s&sslcert=%s&sslkey=%s",
				url.QueryEscape(tempCAPath),
				url.QueryEscape(tempCertPath),
				url.QueryEscape(tempKeyPath),
			),
		}, func() {
			tempCACleanup()
			tempCertCleanup()
			tempKeyCleanup()
		}
}
예제 #2
0
// PGUrl returns a postgres connection url which connects to this server with the given user, and a
// cleanup function which must be called after all connections created using the connection url have
// been closed.
//
// In order to connect securely using postgres, this method will create temporary on-disk copies of
// certain embedded security certificates. The certificates will be created in a new temporary
// directory (named with the given prefix). The returned cleanup function will delete this temporary
// directory. If PGUrl itself aborts after creating the directory (a panic, or t.Fatal while copying
// the certificates), the directory is removed before unwinding so it is not leaked.
func PGUrl(t testing.TB, ts *server.TestServer, user, prefix string) (url.URL, func()) {
	host, port, err := net.SplitHostPort(ts.PGAddr())
	if err != nil {
		t.Fatal(err)
	}

	tempDir, err := ioutil.TempDir("", prefix)
	if err != nil {
		t.Fatal(err)
	}

	cleanup := func() {
		if err := os.RemoveAll(tempDir); err != nil {
			// Not Fatal() because we might already be panicking.
			t.Error(err)
		}
	}

	// Until cleanup has been handed to the caller, we own tempDir. t.Fatal
	// unwinds via runtime.Goexit (which still runs deferred calls), and a
	// panic also runs them, so this removes the directory on any early exit.
	handedOff := false
	defer func() {
		if !handedOff {
			cleanup()
		}
	}()

	caPath := security.CACertPath(security.EmbeddedCertsDir)
	certPath := security.ClientCertPath(security.EmbeddedCertsDir, user)
	keyPath := security.ClientKeyPath(security.EmbeddedCertsDir, user)

	// Copy these assets to disk from embedded strings, so this test can
	// run from a standalone binary.
	tempCAPath := securitytest.RestrictedCopy(t, caPath, tempDir, "ca")
	tempCertPath := securitytest.RestrictedCopy(t, certPath, tempDir, "cert")
	tempKeyPath := securitytest.RestrictedCopy(t, keyPath, tempDir, "key")

	options := url.Values{}
	options.Add("sslmode", "verify-full")
	options.Add("sslrootcert", tempCAPath)
	options.Add("sslcert", tempCertPath)
	options.Add("sslkey", tempKeyPath)

	handedOff = true
	return url.URL{
		Scheme:   "postgres",
		User:     url.User(user),
		Host:     net.JoinHostPort(host, port),
		RawQuery: options.Encode(),
	}, cleanup
}
예제 #3
0
// newCLITest resets the client contexts, starts a fresh test server, and
// writes copies of the embedded CA, node and root certificates/keys into a
// new temporary directory exposed as certsDir. The returned cleanupFunc only
// removes that directory (fatally logging on failure); it does not touch the
// server.
func newCLITest() cliTest {
	// The flags are bound to these structs, so re-initialize their values in
	// place instead of swapping in new pointers.
	baseCtx.InitDefaults()
	cliCtx.InitCLIDefaults()

	osStderr = os.Stdout

	srv, startErr := serverutils.StartServerRaw(base.TestServerArgs{})
	if startErr != nil {
		log.Fatalf(context.Background(), "Could not start server: %v", startErr)
	}

	certsDir, dirErr := ioutil.TempDir("", "cli-test")
	if dirErr != nil {
		log.Fatal(context.Background(), dirErr)
	}

	// Disable embedded certs first, so the security library does not try to
	// load the real files written below as embedded assets.
	security.ResetReadFileFn()

	// Materialize the embedded assets on disk so the test also runs from a
	// standalone binary.
	for _, name := range []string{
		security.EmbeddedCACert,
		security.EmbeddedCAKey,
		security.EmbeddedNodeCert,
		security.EmbeddedNodeKey,
		security.EmbeddedRootCert,
		security.EmbeddedRootKey,
	} {
		src := filepath.Join(security.EmbeddedCertsDir, name)
		securitytest.RestrictedCopy(nil, src, certsDir, filepath.Base(src))
	}

	removeCerts := func() {
		if err := os.RemoveAll(certsDir); err != nil {
			log.Fatal(context.Background(), err)
		}
	}

	return cliTest{
		TestServer:  srv.(*server.TestServer),
		certsDir:    certsDir,
		cleanupFunc: removeCerts,
	}
}
예제 #4
0
// newCLITest resets the client context, starts a new test server, and writes
// copies of the embedded CA cert plus the root and node client certs/keys into
// a fresh temporary directory exposed as certsDir. The returned cleanupFunc
// runs each per-file cleanup and then removes the temporary directory itself;
// it does not stop the server.
func newCLITest() cliTest {
	// Reset the client context for each test. We don't reset the
	// pointer (because they are tied into the flags), but instead
	// overwrite the existing struct's values.
	context.InitDefaults()

	osStderr = os.Stdout

	s := &server.TestServer{}
	if err := s.Start(); err != nil {
		log.Fatalf("Could not start server: %v", err)
	}

	tempDir, err := ioutil.TempDir("", "cli-test")
	if err != nil {
		log.Fatal(err)
	}

	// Copy these assets to disk from embedded strings, so this test can
	// run from a standalone binary.
	// Disable embedded certs, or the security library will try to load
	// our real files as embedded assets.
	security.ResetReadFileFn()

	assets := []string{
		security.CACertPath(security.EmbeddedCertsDir),
		security.ClientCertPath(security.EmbeddedCertsDir, security.RootUser),
		security.ClientKeyPath(security.EmbeddedCertsDir, security.RootUser),
		security.ClientCertPath(security.EmbeddedCertsDir, security.NodeUser),
		security.ClientKeyPath(security.EmbeddedCertsDir, security.NodeUser),
	}

	cleanups := make([]func(), 0, len(assets))
	for _, a := range assets {
		_, cleanupFn := securitytest.RestrictedCopy(nil, a, tempDir, filepath.Base(a))
		cleanups = append(cleanups, cleanupFn)
	}

	return cliTest{
		TestServer: s,
		certsDir:   tempDir,
		cleanupFunc: func() {
			for _, f := range cleanups {
				f()
			}
			// Also remove the directory created by ioutil.TempDir above, which
			// was otherwise left behind. RemoveAll returns nil if it is
			// already gone.
			if err := os.RemoveAll(tempDir); err != nil {
				log.Fatal(err)
			}
		},
	}
}
예제 #5
0
// TestPGWire exercises pgwire connections against both an insecure and a
// secure test server, using no sslmode, sslmode=disable, sslmode=require
// without a client cert, and sslmode=require with the test user's cert for
// both the matching user and the root user.
func TestPGWire(t *testing.T) {
	defer leaktest.AfterTest(t)

	certUser := server.TestUser
	certPath := security.ClientCertPath(security.EmbeddedCertsDir, certUser)
	keyPath := security.ClientKeyPath(security.EmbeddedCertsDir, certUser)

	tempDir, err := ioutil.TempDir("", "TestPGWire")
	if err != nil {
		t.Fatal(err)
	}

	defer func() {
		if err := os.RemoveAll(tempDir); err != nil {
			// Not Fatal() because we might already be panicking.
			t.Error(err)
		}
	}()

	// Copy these assets to disk from embedded strings, so this test can
	// run from a standalone binary.
	tempCertPath := securitytest.RestrictedCopy(t, certPath, tempDir, "cert")
	tempKeyPath := securitytest.RestrictedCopy(t, keyPath, tempDir, "key")

	for _, insecure := range [...]bool{true, false} {
		// Each configuration runs in its own function so the deferred server
		// cleanup fires at the end of the iteration, including when t.Fatal
		// aborts the test part-way through (previously the server leaked).
		func() {
			ctx := server.NewTestContext()
			ctx.Insecure = insecure
			s := setupTestServerWithContext(t, ctx)
			defer cleanupTestServer(s)

			host, port, err := net.SplitHostPort(s.PGAddr())
			if err != nil {
				t.Fatal(err)
			}

			baseURL := url.URL{
				Scheme: "postgres",
				Host:   net.JoinHostPort(host, port),
			}
			if err := trivialQuery(baseURL); err != nil {
				if insecure {
					if err != pq.ErrSSLNotSupported {
						t.Error(err)
					}
				} else {
					if !testutils.IsError(err, "no client certificates in request") {
						t.Error(err)
					}
				}
			}

			{
				disableURL := baseURL
				disableURL.RawQuery = "sslmode=disable"
				err := trivialQuery(disableURL)
				if insecure {
					if err != nil {
						t.Error(err)
					}
				} else {
					if !testutils.IsError(err, pgwire.ErrSSLRequired) {
						t.Error(err)
					}
				}
			}

			{
				requireNoCertURL := baseURL
				requireNoCertURL.RawQuery = "sslmode=require"
				err := trivialQuery(requireNoCertURL)
				if insecure {
					if err != pq.ErrSSLNotSupported {
						t.Error(err)
					}
				} else {
					if !testutils.IsError(err, "no client certificates in request") {
						t.Error(err)
					}
				}
			}

			{
				for _, optUser := range []string{certUser, security.RootUser} {
					requireWithCertURL := baseURL
					requireWithCertURL.User = url.User(optUser)
					requireWithCertURL.RawQuery = fmt.Sprintf("sslmode=require&sslcert=%s&sslkey=%s",
						url.QueryEscape(tempCertPath),
						url.QueryEscape(tempKeyPath),
					)
					err := trivialQuery(requireWithCertURL)
					if insecure {
						if err != pq.ErrSSLNotSupported {
							t.Error(err)
						}
					} else {
						if optUser == certUser {
							if err != nil {
								t.Error(err)
							}
						} else {
							if !testutils.IsError(err, `requested user is \w+, but certificate is for \w+`) {
								t.Error(err)
							}
						}
					}
				}
			}
		}()
	}
}
예제 #6
0
// TestPGWire exercises pgwire connections against both an insecure and a
// secure test server, using no sslmode, sslmode=disable, sslmode=require
// without a client cert, and sslmode=require with the test user's cert for
// both the matching user and the root user.
func TestPGWire(t *testing.T) {
	defer leaktest.AfterTest(t)()

	certPath := filepath.Join(security.EmbeddedCertsDir, security.EmbeddedTestUserCert)
	keyPath := filepath.Join(security.EmbeddedCertsDir, security.EmbeddedTestUserKey)

	tempDir, err := ioutil.TempDir("", "TestPGWire")
	if err != nil {
		t.Fatal(err)
	}

	defer func() {
		if err := os.RemoveAll(tempDir); err != nil {
			// Not Fatal() because we might already be panicking.
			t.Error(err)
		}
	}()

	// Copy these assets to disk from embedded strings, so this test can
	// run from a standalone binary.
	tempCertPath := securitytest.RestrictedCopy(t, certPath, tempDir, "cert")
	tempKeyPath := securitytest.RestrictedCopy(t, keyPath, tempDir, "key")

	for _, insecure := range [...]bool{true, false} {
		// Each configuration runs in its own function so the deferred server
		// stop fires at the end of the iteration, including when t.Fatal
		// aborts the test part-way through (previously the server leaked).
		func() {
			// Only the server params are needed here; the second return
			// value is intentionally ignored.
			params, _ := createTestServerParams()
			params.Insecure = insecure
			s, _, _ := serverutils.StartServer(t, params)
			defer s.Stopper().Stop()

			host, port, err := net.SplitHostPort(s.ServingAddr())
			if err != nil {
				t.Fatal(err)
			}

			pgBaseURL := url.URL{
				Scheme: "postgres",
				Host:   net.JoinHostPort(host, port),
			}
			if err := trivialQuery(pgBaseURL); err != nil {
				if insecure {
					if err != pq.ErrSSLNotSupported {
						t.Error(err)
					}
				} else {
					if !testutils.IsError(err, "no client certificates in request") {
						t.Error(err)
					}
				}
			}

			{
				pgDisableURL := pgBaseURL
				pgDisableURL.RawQuery = "sslmode=disable"
				err := trivialQuery(pgDisableURL)
				if insecure {
					if err != nil {
						t.Error(err)
					}
				} else {
					if !testutils.IsError(err, pgwire.ErrSSLRequired) {
						t.Error(err)
					}
				}
			}

			{
				pgNoCertRequireURL := pgBaseURL
				pgNoCertRequireURL.RawQuery = "sslmode=require"
				err := trivialQuery(pgNoCertRequireURL)
				if insecure {
					if err != pq.ErrSSLNotSupported {
						t.Error(err)
					}
				} else {
					if !testutils.IsError(err, "no client certificates in request") {
						t.Error(err)
					}
				}
			}

			{
				for _, optUser := range []string{server.TestUser, security.RootUser} {
					pgWithCertRequireURL := pgBaseURL
					pgWithCertRequireURL.User = url.User(optUser)
					pgWithCertRequireURL.RawQuery = fmt.Sprintf("sslmode=require&sslcert=%s&sslkey=%s",
						url.QueryEscape(tempCertPath),
						url.QueryEscape(tempKeyPath),
					)
					err := trivialQuery(pgWithCertRequireURL)
					if insecure {
						if err != pq.ErrSSLNotSupported {
							t.Error(err)
						}
					} else {
						if optUser == server.TestUser {
							if err != nil {
								t.Error(err)
							}
						} else {
							if !testutils.IsError(err, `requested user is \w+, but certificate is for \w+`) {
								t.Error(err)
							}
						}
					}
				}
			}
		}()
	}
}