Example #1
0
// TestServerRunWithSNI starts one secure server per table entry, configured
// with a default serving cert plus optional SNI certs, and checks two things:
//
//  1. a TLS client dialing with test.ServerName is handed the certificate at
//     ExpectedCertIndex (-1 is the default server cert, values >= 0 index
//     into SNICerts);
//  2. the loopback ("self") client config either fails to build
//     (ExpectSelfClientError) or yields a client that can read back the
//     server version.
func TestServerRunWithSNI(t *testing.T) {
	tests := map[string]struct {
		Cert              TestCertSpec
		SNICerts          []NamedTestCertSpec
		ExpectedCertIndex int

		// passed in the client hello info, "localhost" if unset
		ServerName string

		// optional ip or hostname to pass to NewSelfClientConfig
		SelfClientBindAddressOverride string
		ExpectSelfClientError         bool
	}{
		"only one cert": {
			Cert: TestCertSpec{
				host: "localhost",
				ips:  []string{"127.0.0.1"},
			},
			ExpectedCertIndex: -1,
		},
		"cert with multiple alternate names": {
			Cert: TestCertSpec{
				host:  "localhost",
				names: []string{"test.com"},
				ips:   []string{"127.0.0.1"},
			},
			ExpectedCertIndex: -1,
			ServerName:        "test.com",
		},
		"one SNI and the default cert with the same name": {
			Cert: TestCertSpec{
				host: "localhost",
				ips:  []string{"127.0.0.1"},
			},
			SNICerts: []NamedTestCertSpec{
				{
					TestCertSpec: TestCertSpec{
						host: "localhost",
					},
				},
			},
			ExpectedCertIndex: 0,
		},
		"matching SNI cert": {
			Cert: TestCertSpec{
				host: "localhost",
				ips:  []string{"127.0.0.1"},
			},
			SNICerts: []NamedTestCertSpec{
				{
					TestCertSpec: TestCertSpec{
						host: "test.com",
					},
				},
			},
			ExpectedCertIndex: 0,
			ServerName:        "test.com",
		},
		"matching IP in SNI cert and the server cert": {
			// IPs must not be passed via SNI. Hence, the ServerName in the
			// HELLO packet is empty and the server should select the non-SNI cert.
			Cert: TestCertSpec{
				host: "localhost",
				ips:  []string{"10.0.0.1", "127.0.0.1"},
			},
			SNICerts: []NamedTestCertSpec{
				{
					TestCertSpec: TestCertSpec{
						host: "test.com",
						ips:  []string{"10.0.0.1"},
					},
				},
			},
			ExpectedCertIndex: -1,
			ServerName:        "10.0.0.1",
		},
		"wildcards": {
			Cert: TestCertSpec{
				host: "localhost",
				ips:  []string{"127.0.0.1"},
			},
			SNICerts: []NamedTestCertSpec{
				{
					TestCertSpec: TestCertSpec{
						host:  "test.com",
						names: []string{"*.test.com"},
					},
				},
			},
			ExpectedCertIndex: 0,
			ServerName:        "www.test.com",
		},

		"loopback: IP for loopback client on SNI cert": {
			Cert: TestCertSpec{
				host: "localhost",
			},
			SNICerts: []NamedTestCertSpec{
				{
					TestCertSpec: TestCertSpec{
						host: "test.com",
						ips:  []string{"127.0.0.1"},
					},
				},
			},
			ExpectedCertIndex:     -1,
			ExpectSelfClientError: true,
		},
		"loopback: IP for loopback client on server and SNI cert": {
			Cert: TestCertSpec{
				ips:  []string{"127.0.0.1"},
				host: "localhost",
			},
			SNICerts: []NamedTestCertSpec{
				{
					TestCertSpec: TestCertSpec{
						host: "test.com",
						ips:  []string{"127.0.0.1"},
					},
				},
			},
			ExpectedCertIndex: -1,
		},
		"loopback: bind to 0.0.0.0 => loopback uses localhost; localhost on server cert": {
			Cert: TestCertSpec{
				host: "localhost",
			},
			SNICerts: []NamedTestCertSpec{
				{
					TestCertSpec: TestCertSpec{
						host: "test.com",
					},
				},
			},
			ExpectedCertIndex:             -1,
			SelfClientBindAddressOverride: "0.0.0.0",
		},
		"loopback: bind to 0.0.0.0 => loopback uses localhost; localhost on SNI cert": {
			Cert: TestCertSpec{
				host: "test.com",
			},
			SNICerts: []NamedTestCertSpec{
				{
					TestCertSpec: TestCertSpec{
						host: "localhost",
					},
				},
			},
			ExpectedCertIndex:             0,
			SelfClientBindAddressOverride: "0.0.0.0",
		},
		"loopback: bind to 0.0.0.0 => loopback uses localhost; localhost on server and SNI cert": {
			Cert: TestCertSpec{
				host: "localhost",
			},
			SNICerts: []NamedTestCertSpec{
				{
					TestCertSpec: TestCertSpec{
						host: "localhost",
					},
				},
			},
			ExpectedCertIndex:             0,
			SelfClientBindAddressOverride: "0.0.0.0",
		},
	}

	// All generated cert/key files live in one scratch directory that is
	// removed when the test function returns.
	tempDir, err := ioutil.TempDir("", "")
	if err != nil {
		t.Fatal(err)
	}
	defer os.RemoveAll(tempDir)

	for title, test := range tests {
		// Each case runs inside its own function literal so that the deferred
		// cleanup below fires when the case ends. A defer placed directly in
		// the loop body would only run when the whole test function returns,
		// keeping every etcd instance and server alive until then.
		func() {
			// create server cert
			serverCertBundleFile, serverKeyFile, err := createTestCertFiles(tempDir, test.Cert)
			if err != nil {
				t.Errorf("%q - failed to create server cert: %v", title, err)
				return
			}
			ca, err := caCertFromBundle(serverCertBundleFile)
			if err != nil {
				t.Errorf("%q - failed to extract ca cert from server cert bundle: %v", title, err)
				return
			}
			caCerts := []*x509.Certificate{ca}

			// create SNI certs
			var namedCertKeys []config.NamedCertKey
			serverSig, err := certFileSignature(serverCertBundleFile, serverKeyFile)
			if err != nil {
				t.Errorf("%q - failed to get server cert signature: %v", title, err)
				return
			}
			// signatures maps a cert signature to its index: -1 for the
			// default server cert, j for the j-th SNI cert.
			signatures := map[string]int{
				serverSig: -1,
			}
			for j, c := range test.SNICerts {
				certBundleFile, keyFile, err := createTestCertFiles(tempDir, c.TestCertSpec)
				if err != nil {
					t.Errorf("%q - failed to create SNI cert %d: %v", title, j, err)
					return
				}

				namedCertKeys = append(namedCertKeys, config.NamedCertKey{
					KeyFile:  keyFile,
					CertFile: certBundleFile,
					Names:    c.explicitNames,
				})

				ca, err := caCertFromBundle(certBundleFile)
				if err != nil {
					t.Errorf("%q - failed to extract ca cert from SNI cert %d: %v", title, j, err)
					return
				}
				caCerts = append(caCerts, ca)

				// store index in namedCertKeys with the signature as the key
				sig, err := certFileSignature(certBundleFile, keyFile)
				if err != nil {
					t.Errorf("%q - failed get SNI cert %d signature: %v", title, j, err)
					return
				}
				signatures[sig] = j
			}

			// launch server; the local is named serverConfig so it does not
			// shadow the config package used for NamedCertKey above.
			etcdserver, serverConfig, _ := setUp(t)
			defer etcdserver.Terminate(t)

			// stopCh is handed to serveSecurely below; closing it at the end
			// of the case is presumably the server's shutdown signal (confirm
			// against serveSecurely). Deferred after Terminate so it runs first.
			stopCh := make(chan struct{})
			defer close(stopCh)

			v := fakeVersion()
			serverConfig.Version = &v

			serverConfig.EnableIndex = true
			_, err = serverConfig.ApplySecureServingOptions(&options.SecureServingOptions{
				ServingOptions: options.ServingOptions{
					BindAddress: net.ParseIP("127.0.0.1"),
					BindPort:    6443,
				},
				ServerCert: options.GeneratableKeyCert{
					CertKey: options.CertKey{
						CertFile: serverCertBundleFile,
						KeyFile:  serverKeyFile,
					},
				},
				SNICertKeys: namedCertKeys,
			})
			if err != nil {
				t.Errorf("%q - failed applying the SecureServingOptions: %v", title, err)
				return
			}
			serverConfig.InsecureServingInfo = nil

			s, err := serverConfig.Complete().New()
			if err != nil {
				t.Errorf("%q - failed creating the server: %v", title, err)
				return
			}

			// patch in a 0-port to enable auto port allocation
			s.SecureServingInfo.BindAddress = "127.0.0.1:0"

			if err := s.serveSecurely(stopCh); err != nil {
				t.Errorf("%q - failed running the server: %v", title, err)
				return
			}

			// load ca certificates into a pool
			roots := x509.NewCertPool()
			for _, caCert := range caCerts {
				roots.AddCert(caCert)
			}

			// try to dial
			addr := fmt.Sprintf("localhost:%d", s.effectiveSecurePort)
			t.Logf("Dialing %s as %q", addr, test.ServerName)
			conn, err := tls.Dial("tcp", addr, &tls.Config{
				RootCAs:    roots,
				ServerName: test.ServerName, // used for SNI in the client HELLO packet
			})
			if err != nil {
				t.Errorf("%q - failed to connect: %v", title, err)
				return
			}

			// check returned server certificate; the index is only compared
			// when the signature is known, otherwise the map's zero value
			// would produce a second, misleading mismatch report.
			sig := x509CertSignature(conn.ConnectionState().PeerCertificates[0])
			gotCertIndex, found := signatures[sig]
			if !found {
				t.Errorf("%q - unknown signature returned from server: %s", title, sig)
			} else if gotCertIndex != test.ExpectedCertIndex {
				t.Errorf("%q - expected cert index %d, got cert index %d", title, test.ExpectedCertIndex, gotCertIndex)
			}

			conn.Close()

			// check that the loopback client can connect
			host := "127.0.0.1"
			if len(test.SelfClientBindAddressOverride) != 0 {
				host = test.SelfClientBindAddressOverride
			}
			serverConfig.SecureServingInfo.ServingInfo.BindAddress = net.JoinHostPort(host, strconv.Itoa(s.effectiveSecurePort))
			cfg, err := serverConfig.SecureServingInfo.NewSelfClientConfig("some-token")
			if test.ExpectSelfClientError {
				if err == nil {
					t.Errorf("%q - expected error creating loopback client config", title)
				}
				return
			}
			if err != nil {
				t.Errorf("%q - failed creating loopback client config: %v", title, err)
				return
			}
			client, err := clientset.NewForConfig(cfg)
			if err != nil {
				t.Errorf("%q - failed to create loopback client: %v", title, err)
				return
			}
			got, err := client.ServerVersion()
			if err != nil {
				t.Errorf("%q - failed to connect with loopback client: %v", title, err)
				return
			}
			if expected := &v; !reflect.DeepEqual(got, expected) {
				t.Errorf("%q - loopback client didn't get correct version info: expected=%v got=%v", title, expected, got)
			}
		}()
	}
}
Example #2
0
// TestServerRunWithSNI starts one secure server per table entry, configured
// with a default serving cert plus optional SNI certs, and checks that a TLS
// client dialing with test.ServerName is handed the certificate at
// ExpectedCertIndex (-1 is the default server cert, values >= 0 index into
// SNICerts).
func TestServerRunWithSNI(t *testing.T) {
	tests := []struct {
		Cert              TestCertSpec
		SNICerts          []NamedTestCertSpec
		ExpectedCertIndex int

		// passed in the client hello info, "localhost" if unset
		ServerName string
	}{
		{
			// only one cert
			Cert: TestCertSpec{
				host: "localhost",
			},
			ExpectedCertIndex: -1,
		},
		{
			// cert with multiple alternate names
			Cert: TestCertSpec{
				host:  "localhost",
				names: []string{"test.com"},
				ips:   []string{"127.0.0.1"},
			},
			ExpectedCertIndex: -1,
			ServerName:        "test.com",
		},
		{
			// one SNI and the default cert with the same name
			Cert: TestCertSpec{
				host: "localhost",
			},
			SNICerts: []NamedTestCertSpec{
				{
					TestCertSpec: TestCertSpec{
						host: "localhost",
					},
				},
			},
			ExpectedCertIndex: 0,
		},
		{
			// matching SNI cert
			Cert: TestCertSpec{
				host: "localhost",
			},
			SNICerts: []NamedTestCertSpec{
				{
					TestCertSpec: TestCertSpec{
						host: "test.com",
					},
				},
			},
			ExpectedCertIndex: 0,
			ServerName:        "test.com",
		},
		{
			// matching IP in SNI cert and the server cert. But IPs must not be
			// passed via SNI. Hence, the ServerName in the HELLO packet is empty
			// and the server should select the non-SNI cert.
			Cert: TestCertSpec{
				host: "localhost",
				ips:  []string{"10.0.0.1"},
			},
			SNICerts: []NamedTestCertSpec{
				{
					TestCertSpec: TestCertSpec{
						host: "test.com",
						ips:  []string{"10.0.0.1"},
					},
				},
			},
			ExpectedCertIndex: -1,
			ServerName:        "10.0.0.1",
		},
		{
			// wildcards
			Cert: TestCertSpec{
				host: "localhost",
			},
			SNICerts: []NamedTestCertSpec{
				{
					TestCertSpec: TestCertSpec{
						host:  "test.com",
						names: []string{"*.test.com"},
					},
				},
			},
			ExpectedCertIndex: 0,
			ServerName:        "www.test.com",
		},
	}

	// All generated cert/key files live in one scratch directory that is
	// removed when the test function returns.
	tempDir, err := ioutil.TempDir("", "")
	if err != nil {
		t.Fatal(err)
	}
	defer os.RemoveAll(tempDir)

	for i, test := range tests {
		// Each case runs inside its own function literal so that the deferred
		// cleanup below fires when the case ends. A defer placed directly in
		// the loop body would only run when the whole test function returns,
		// keeping every etcd instance and server alive until then.
		func() {
			// create server cert
			serverCertBundleFile, serverKeyFile, err := createTestCertFiles(tempDir, test.Cert)
			if err != nil {
				t.Errorf("%d - failed to create server cert: %v", i, err)
				return
			}
			ca, err := caCertFromBundle(serverCertBundleFile)
			if err != nil {
				t.Errorf("%d - failed to extract ca cert from server cert bundle: %v", i, err)
				return
			}
			caCerts := []*x509.Certificate{ca}

			// create SNI certs
			var namedCertKeys []config.NamedCertKey
			serverSig, err := certFileSignature(serverCertBundleFile, serverKeyFile)
			if err != nil {
				t.Errorf("%d - failed to get server cert signature: %v", i, err)
				return
			}
			// signatures maps a cert signature to its index: -1 for the
			// default server cert, j for the j-th SNI cert.
			signatures := map[string]int{
				serverSig: -1,
			}
			for j, c := range test.SNICerts {
				certBundleFile, keyFile, err := createTestCertFiles(tempDir, c.TestCertSpec)
				if err != nil {
					t.Errorf("%d - failed to create SNI cert %d: %v", i, j, err)
					return
				}

				namedCertKeys = append(namedCertKeys, config.NamedCertKey{
					KeyFile:  keyFile,
					CertFile: certBundleFile,
					Names:    c.explicitNames,
				})

				ca, err := caCertFromBundle(certBundleFile)
				if err != nil {
					t.Errorf("%d - failed to extract ca cert from SNI cert %d: %v", i, j, err)
					return
				}
				caCerts = append(caCerts, ca)

				// store index in namedCertKeys with the signature as the key
				sig, err := certFileSignature(certBundleFile, keyFile)
				if err != nil {
					t.Errorf("%d - failed get SNI cert %d signature: %v", i, j, err)
					return
				}
				signatures[sig] = j
			}

			// launch server; the local is named serverConfig so it does not
			// shadow the config package used for NamedCertKey above.
			etcdserver, serverConfig, _ := setUp(t)
			defer etcdserver.Terminate(t)

			// stopCh is handed to serveSecurely below; closing it at the end
			// of the case is presumably the server's shutdown signal (confirm
			// against serveSecurely). Deferred after Terminate so it runs first.
			stopCh := make(chan struct{})
			defer close(stopCh)

			serverConfig.EnableIndex = true
			_, err = serverConfig.ApplySecureServingOptions(&options.SecureServingOptions{
				ServingOptions: options.ServingOptions{
					BindAddress: net.ParseIP("127.0.0.1"),
					BindPort:    6443,
				},
				ServerCert: options.GeneratableKeyCert{
					CertKey: options.CertKey{
						CertFile: serverCertBundleFile,
						KeyFile:  serverKeyFile,
					},
				},
				SNICertKeys: namedCertKeys,
			})
			if err != nil {
				t.Errorf("%d - failed applying the SecureServingOptions: %v", i, err)
				return
			}
			serverConfig.InsecureServingInfo = nil

			s, err := serverConfig.Complete().New()
			if err != nil {
				t.Errorf("%d - failed creating the server: %v", i, err)
				return
			}

			// patch in a 0-port to enable auto port allocation
			s.SecureServingInfo.BindAddress = "127.0.0.1:0"

			if err := s.serveSecurely(stopCh); err != nil {
				t.Errorf("%d - failed running the server: %v", i, err)
				return
			}

			// load ca certificates into a pool
			roots := x509.NewCertPool()
			for _, caCert := range caCerts {
				roots.AddCert(caCert)
			}

			// try to dial
			addr := fmt.Sprintf("localhost:%d", s.effectiveSecurePort)
			t.Logf("Dialing %s as %q", addr, test.ServerName)
			conn, err := tls.Dial("tcp", addr, &tls.Config{
				RootCAs:    roots,
				ServerName: test.ServerName, // used for SNI in the client HELLO packet
			})
			if err != nil {
				t.Errorf("%d - failed to connect: %v", i, err)
				return
			}
			defer conn.Close()

			// check returned server certificate; the index is only compared
			// when the signature is known, otherwise the map's zero value
			// would produce a second, misleading mismatch report.
			sig := x509CertSignature(conn.ConnectionState().PeerCertificates[0])
			gotCertIndex, found := signatures[sig]
			if !found {
				t.Errorf("%d - unknown signature returned from server: %s", i, sig)
			} else if gotCertIndex != test.ExpectedCertIndex {
				t.Errorf("%d - expected cert index %d, got cert index %d", i, test.ExpectedCertIndex, gotCertIndex)
			}
		}()
	}
}