Example #1
0
// newCLITest prepares the environment for a CLI test. It creates a temporary
// certs directory, resets the shared client configuration, starts a raw test
// server (insecure or secure, according to the insecure argument) and
// redirects CLI error output to stdout.
//
// In secure mode the embedded certificates are copied into the temporary
// directory, and the returned cliTest's cleanupFunc also re-installs
// securitytest.Asset as the security package's file reader. In both modes
// cleanupFunc removes the temporary directory.
//
// If the temporary directory cannot be created or the server fails to start,
// a zero cliTest and the error are returned; in the latter case the temporary
// directory is removed again so that it is not leaked.
func newCLITest(t *testing.T, insecure bool) (cliTest, error) {
	c := cliTest{}

	certsDir, err := ioutil.TempDir("", "cli-test")
	if err != nil {
		return cliTest{}, err
	}
	c.certsDir = certsDir

	// Reset the client context for each test. We don't reset the
	// pointer (because they are tied into the flags), but instead
	// overwrite the existing struct's values.
	baseCfg.InitDefaults()
	cliCtx.InitCLIDefaults()

	s, err := serverutils.StartServerRaw(base.TestServerArgs{Insecure: insecure})
	if err != nil {
		// Best-effort removal of the directory created above: the startup
		// error is the one worth reporting, so a removal failure is dropped.
		_ = os.RemoveAll(certsDir)
		return cliTest{}, err
	}
	c.TestServer = s.(*server.TestServer)

	if insecure {
		// No certificates are written in insecure mode, but the temporary
		// directory itself still has to be removed.
		c.cleanupFunc = func() error { return os.RemoveAll(certsDir) }
	} else {
		// Copy these assets to disk from embedded strings, so this test can
		// run from a standalone binary.
		// Disable embedded certs, or the security library will try to load
		// our real files as embedded assets.
		security.ResetReadFileFn()

		assets := []string{
			filepath.Join(security.EmbeddedCertsDir, security.EmbeddedCACert),
			filepath.Join(security.EmbeddedCertsDir, security.EmbeddedCAKey),
			filepath.Join(security.EmbeddedCertsDir, security.EmbeddedNodeCert),
			filepath.Join(security.EmbeddedCertsDir, security.EmbeddedNodeKey),
			filepath.Join(security.EmbeddedCertsDir, security.EmbeddedRootCert),
			filepath.Join(security.EmbeddedCertsDir, security.EmbeddedRootKey),
		}

		// NOTE(review): a literal nil (not t) is passed as the tester here;
		// t itself may be nil at some call sites, and wrapping a nil
		// *testing.T in an interface would yield a non-nil value — confirm
		// before changing.
		for _, a := range assets {
			securitytest.RestrictedCopy(nil, a, certsDir, filepath.Base(a))
		}

		c.cleanupFunc = func() error {
			security.SetReadFileFn(securitytest.Asset)
			return os.RemoveAll(certsDir)
		}
	}

	// Ensure that CLI error messages are logged to stdout, where they
	// can be captured.
	osStderr = os.Stdout

	return c, nil
}
Example #2
0
// PGUrl builds a postgres connection URL for the server listening on
// servingAddr, authenticating as the given user, and returns it together with
// a cleanup function.
//
// To allow a secure connection, the embedded CA certificate and the user's
// embedded certificate and key are copied into a freshly created temporary
// directory, and the URL's sslrootcert, sslcert and sslkey options point at
// those copies. Every call creates a new directory, even for the same user,
// so the returned cleanup function must always be called; it deletes the
// directory and should only run once all connections opened through the URL
// have been closed.
//
// The prefix argument is used as the name prefix of the temporary directory,
// which helps when debugging. Failures while splitting the address or creating
// the directory are reported through t.Fatal.
func PGUrl(t testing.TB, servingAddr, prefix string, user *url.Userinfo) (url.URL, func()) {
	host, port, err := net.SplitHostPort(servingAddr)
	if err != nil {
		t.Fatal(err)
	}

	certsDir, err := ioutil.TempDir("", prefix)
	if err != nil {
		t.Fatal(err)
	}

	// embedded maps an asset file name to its path among the embedded certs.
	embedded := func(name string) string {
		return filepath.Join(security.EmbeddedCertsDir, name)
	}
	caAsset := embedded(security.EmbeddedCACert)
	certAsset := embedded(fmt.Sprintf("%s.crt", user.Username()))
	keyAsset := embedded(fmt.Sprintf("%s.key", user.Username()))

	// Materialize the embedded assets on disk so that the test can also run
	// from a standalone binary.
	caFile := securitytest.RestrictedCopy(t, caAsset, certsDir, "ca")
	certFile := securitytest.RestrictedCopy(t, certAsset, certsDir, "cert")
	keyFile := securitytest.RestrictedCopy(t, keyAsset, certsDir, "key")

	query := url.Values{
		"sslmode":     {"verify-full"},
		"sslrootcert": {caFile},
		"sslcert":     {certFile},
		"sslkey":      {keyFile},
	}

	pgURL := url.URL{
		Scheme:   "postgres",
		User:     user,
		Host:     net.JoinHostPort(host, port),
		RawQuery: query.Encode(),
	}
	cleanup := func() {
		// Error() rather than Fatal(): a panic may already be in flight.
		if err := os.RemoveAll(certsDir); err != nil {
			t.Error(err)
		}
	}
	return pgURL, cleanup
}
Example #3
0
// newCLITest resets the shared CLI configuration, redirects CLI error output
// to stdout, starts a raw test server and copies the embedded certificates
// into a new temporary directory. The returned cliTest carries the server,
// the directory path and a cleanupFunc that removes the directory.
//
// Any failure (server start, directory creation, directory removal during
// cleanup) is reported through log.Fatal / log.Fatalf.
func newCLITest() cliTest {
	// The configuration structs are tied into the flags through their
	// pointers, so their values are overwritten in place for each test
	// rather than replacing the pointers.
	baseCfg.InitDefaults()
	cliCtx.InitCLIDefaults()

	osStderr = os.Stdout

	ctx := context.Background()

	srv, err := serverutils.StartServerRaw(base.TestServerArgs{})
	if err != nil {
		log.Fatalf(ctx, "Could not start server: %s", err)
	}

	certsDir, err := ioutil.TempDir("", "cli-test")
	if err != nil {
		log.Fatal(ctx, err)
	}

	// Embedded certs are disabled first; otherwise the security library
	// would try to load the on-disk copies as embedded assets.
	security.ResetReadFileFn()

	// Write the embedded assets to disk so that the test can run from a
	// standalone binary.
	for _, name := range []string{
		security.EmbeddedCACert,
		security.EmbeddedCAKey,
		security.EmbeddedNodeCert,
		security.EmbeddedNodeKey,
		security.EmbeddedRootCert,
		security.EmbeddedRootKey,
	} {
		asset := filepath.Join(security.EmbeddedCertsDir, name)
		securitytest.RestrictedCopy(nil, asset, certsDir, filepath.Base(asset))
	}

	removeCerts := func() {
		if err := os.RemoveAll(certsDir); err != nil {
			log.Fatal(ctx, err)
		}
	}

	return cliTest{
		TestServer:  srv.(*server.TestServer),
		certsDir:    certsDir,
		cleanupFunc: removeCerts,
	}
}
Example #4
0
// TestPGWire checks how the pgwire endpoint reacts to clients connecting with
// different sslmode settings and client certificates, once against an
// insecure test server and once against a secure one.
//
// The test user's embedded certificate and key are copied into a temporary
// directory that is removed when the test ends. Each server mode runs inside
// its own function literal so that the server is stopped through a defer even
// when the iteration is aborted by t.Fatal.
func TestPGWire(t *testing.T) {
	defer leaktest.AfterTest(t)()

	certPath := filepath.Join(security.EmbeddedCertsDir, security.EmbeddedTestUserCert)
	keyPath := filepath.Join(security.EmbeddedCertsDir, security.EmbeddedTestUserKey)

	tempDir, err := ioutil.TempDir("", "TestPGWire")
	if err != nil {
		t.Fatal(err)
	}

	defer func() {
		if err := os.RemoveAll(tempDir); err != nil {
			// Not Fatal() because we might already be panicking.
			t.Error(err)
		}
	}()

	// Copy these assets to disk from embedded strings, so this test can
	// run from a standalone binary.
	tempCertPath := securitytest.RestrictedCopy(t, certPath, tempDir, "cert")
	tempKeyPath := securitytest.RestrictedCopy(t, keyPath, tempDir, "key")

	for _, insecure := range [...]bool{true, false} {
		func() {
			params, _ := createTestServerParams()
			params.Insecure = insecure
			s, _, _ := serverutils.StartServer(t, params)
			// Deferred so the server is also stopped when t.Fatal below
			// terminates this iteration early.
			defer s.Stopper().Stop()

			host, port, err := net.SplitHostPort(s.ServingAddr())
			if err != nil {
				t.Fatal(err)
			}

			pgBaseURL := url.URL{
				Scheme: "postgres",
				User:   url.User(security.RootUser),
				Host:   net.JoinHostPort(host, port),
			}

			// No sslmode given at all.
			if err := trivialQuery(pgBaseURL); err != nil {
				if insecure {
					if err != pq.ErrSSLNotSupported {
						t.Error(err)
					}
				} else {
					// No certificates provided in secure mode defaults to password
					// authentication. This is disallowed for security.RootUser.
					if !testutils.IsError(err, fmt.Sprintf("pq: user %s must authenticate using a client certificate", security.RootUser)) {
						t.Errorf("unexpected error: %v", err)
					}
				}
			}

			// sslmode=disable: must succeed on an insecure server and be
			// rejected with ErrSSLRequired on a secure one.
			{
				pgDisableURL := pgBaseURL
				pgDisableURL.RawQuery = "sslmode=disable"
				err := trivialQuery(pgDisableURL)
				if insecure {
					if err != nil {
						t.Error(err)
					}
				} else {
					if !testutils.IsError(err, pgwire.ErrSSLRequired) {
						t.Error(err)
					}
				}
			}

			// sslmode=require without a client certificate.
			{
				pgNoCertRequireURL := pgBaseURL
				pgNoCertRequireURL.RawQuery = "sslmode=require"
				err := trivialQuery(pgNoCertRequireURL)
				if insecure {
					if err != pq.ErrSSLNotSupported {
						t.Error(err)
					}
				} else {
					if !testutils.IsError(err, fmt.Sprintf("pq: user %s must authenticate using a client certificate", security.RootUser)) {
						t.Errorf("unexpected error: %v", err)
					}
				}
			}

			// sslmode=require with the test user's certificate, connecting
			// first as the test user and then as the root user.
			{
				for _, optUser := range []string{server.TestUser, security.RootUser} {
					pgWithCertRequireURL := pgBaseURL
					pgWithCertRequireURL.User = url.User(optUser)
					pgWithCertRequireURL.RawQuery = fmt.Sprintf("sslmode=require&sslcert=%s&sslkey=%s",
						url.QueryEscape(tempCertPath),
						url.QueryEscape(tempKeyPath),
					)
					err := trivialQuery(pgWithCertRequireURL)
					if insecure {
						if err != pq.ErrSSLNotSupported {
							t.Error(err)
						}
					} else {
						if optUser == server.TestUser {
							// The user TestUser has not been created so authentication
							// will fail with a valid certificate.
							if !testutils.IsError(err, fmt.Sprintf("pq: user %s does not exist", server.TestUser)) {
								t.Errorf("unexpected error: %v", err)
							}
						} else {
							// The certificate on disk belongs to the test user,
							// so a connection as another user must be refused.
							if !testutils.IsError(err, `requested user is \w+, but certificate is for \w+`) {
								t.Error(err)
							}
						}
					}
				}
			}
		}()
	}
}