func check(t *testing.T, clientName, serverName string, shouldPass bool) { clientConfig, err := ttls.ConfigureClient(clientName, "test-tls-ca") if err != nil { t.Errorf("Bad client key '%s': %s", clientName, err) return } serverConfig, err := ttls.ConfigureServer(serverName, "test-tls-ca") if err != nil { t.Errorf("Bad server key '%s': %s", serverName, err) return } if clientConfig.MinVersion != tls.VersionTLS12 { t.Errorf("expected TLS minimum version of %v, got %v", clientConfig.MinVersion, tls.VersionTLS12) } if serverConfig.MinVersion != tls.VersionTLS12 { t.Errorf("expected TLS minimum version of %v, got %v", serverConfig.MinVersion, tls.VersionTLS12) } for _, suite := range serverConfig.CipherSuites { if badCipherSuites[suite] { t.Errorf("server allows blacklisted cipher suite %v", suite) } } server := NewMockServer() server.SetListener(tls.NewListener(server.Listener(), serverConfig)) listener := server.Listener() testData := strings.Repeat("x", 1<<16) go func() { server.SignalReady() conn, err := listener.Accept() if err == nil { defer conn.Close() conn.SetDeadline(time.Now().Add(time.Second)) data, err := ioutil.ReadAll(conn) if err == nil && string(data) != testData { err = fmt.Errorf("Server read incorrect data; got '%s', expected '%s'", string(data), testData) } if err != nil && shouldPass { t.Errorf("Server read error: %v", err) } else if err == nil && !shouldPass { t.Errorf("Expected server read error: %v", err) } } else { t.Errorf("Listener error: %v", err) } server.WaitForShutdown() server.SignalFinish() }() server.WaitForReady() addr := listener.Addr().String() conn, err := tls.Dial("tcp", addr, clientConfig) if err != nil && shouldPass { t.Errorf("Client connection error: %v", err) return } else if err == nil && !shouldPass { t.Errorf("Expected client connection error: %v", err) return } else if err != nil { return } // else err == nil && shouldPass conn.SetDeadline(time.Now().Add(time.Second)) n, err := io.WriteString(conn, testData) if err == nil && n < len(testData) { err = fmt.Errorf("Client incomplete write: expected %d bytes, got %d", len(testData), n) } if err != nil && shouldPass { t.Errorf("Client write error: %v", err) } else if err == nil && !shouldPass { t.Errorf("Expected client write error: %v", err) } conn.Close() server.Shutdown() }
// test that CertCreator makes client/server certs that 1) work when used // correctly and 2) don't work when used with different CAs func TestCertCreator(t *testing.T) { cc := tls.NewCertCreator() cc.KeySize = 512 cc.Country = "US" cc.State = "CA" cc.City = "San Francisco" cc.Organization = "Fastly Testing" var wg sync.WaitGroup done := false type group struct { name string clientConfig *gotls.Config serverConfig *gotls.Config listener net.Listener errors chan error cleanup func() } // generate ca, client, and server keypairs with the given name, and stand // up a TLS listener setup := func(name string) (g *group, err error) { g = &group{ name: name, errors: make(chan error), cleanup: func() { for _, f := range []string{ "-ca-key.pem", "-ca-cert.pem", "client-key.pem", "client-cert.pem", "server-key.pem", "server-cert.pem", } { file := "testcerts/" + name + f if err := os.Remove(file); err != nil { t.Errorf("couldn't remove %q: %s", err) } } }, } // put certs in testcerts/ since LocatePackagedPEMDir looks there root, err := cc.GenerateRootKeyPair("testcerts/"+name+"-ca", name+" CA") if err != nil { t.Error(err) return } host := "0.0.0.0" _, err = cc.GenerateKeyPair(tls.CLIENT, root, "testcerts/"+name+"client", host, host) if err != nil { t.Error(err) return } _, err = cc.GenerateKeyPair(tls.SERVER, root, "testcerts/"+name+"server", host, host) if err != nil { t.Error(err) return } g.clientConfig, err = tls.ConfigureClient(name+"client", name+"-ca") if err != nil { t.Errorf("%s client: %s", name+"client", err) return } g.serverConfig, err = tls.ConfigureServer(name+"server", name+"-ca") if err != nil { t.Errorf("%s server: %s", name+"server", err) return } listener, err := gotls.Listen("tcp4", host+":0", g.serverConfig) if err != nil { t.Errorf("%s listen: %s", name, err) return } g.listener = listener g.clientConfig.ServerName = g.listener.Addr().(*net.TCPAddr).IP.String() // required for client to validate server cert wg.Add(1) go func() { defer wg.Done() for { conn, err := listener.Accept() if done { return } if err != nil { g.errors <- err return } if conn == nil { g.errors <- fmt.Errorf("no conn or err") return } wg.Add(1) go func(conn net.Conn) { defer wg.Done() defer conn.Close() conn.SetDeadline(time.Now().Add(time.Second)) err = conn.(*gotls.Conn).Handshake() if err != nil { g.errors <- err return } conn.SetDeadline(time.Now().Add(time.Second)) b := make([]byte, 3) // len("foo") n, err := conn.Read(b) b = b[:n] if err != nil { g.errors <- err return } if string(b) != "foo" { g.errors <- fmt.Errorf("%q != %q", b, "foo") return } conn.SetDeadline(time.Now().Add(time.Second)) n, err = io.WriteString(conn, "bar") if err != nil { g.errors <- err return } g.errors <- nil }(conn) } }() return } // try to create a TLS connection from client to server and send a bit of // data back and forth check := func(client, server *group, expectedClientError, expectedServerError string) { addr := server.listener.Addr().String() conn, err := net.DialTimeout("tcp4", addr, time.Second) if err != nil { t.Errorf("%s->%s: dial: %s", client.name, server.name, err) return } tlsConn := gotls.Client(conn, client.clientConfig) clientErr := tlsConn.Handshake() if clientErr == nil { tlsConn.SetDeadline(time.Now().Add(time.Second)) if _, err := io.WriteString(tlsConn, "foo"); err == nil { tlsConn.SetDeadline(time.Now().Add(time.Second)) b := make([]byte, 3) // len("bar") n, err := tlsConn.Read(b) b = b[:n] if err != nil { clientErr = err } if string(b) != "bar" { clientErr = fmt.Errorf("%q != %q", b, "bar") } } else { clientErr = err } tlsConn.Close() } var serverErr error select { case serverErr = <-server.errors: case <-time.After(time.Second): t.Errorf("timed out on serverErr") return } if expectedClientError == "" && clientErr != nil { t.Errorf("%s->%s client: should've worked but saw err=`%s`", client.name, server.name, clientErr) } if expectedClientError != "" { if clientErr == nil { t.Errorf("%s->%s client: should not have worked", client.name, server.name) } else if clientErr.Error() != expectedClientError { t.Errorf("%s->%s client: expected error %q but got %q", client.name, server.name, expectedClientError, clientErr) } } if expectedServerError == "" && serverErr != nil { t.Errorf("%s->%s server: should've worked but saw err=`%s`", client.name, server.name, serverErr) } if expectedServerError != "" { if serverErr == nil { t.Errorf("%s->%s server: should not have worked", client.name, server.name) } else if serverErr.Error() != expectedServerError { t.Errorf("%s->%s server: expected error %q but got %q", client.name, server.name, expectedServerError, serverErr) } } } // test that friend client talks to friend server, but friend client does // not talk to foe server or foe client to friend server friend, err := setup("friend") if err != nil { t.Errorf("setup: %s", err) return } defer friend.cleanup() foe, err := setup("foe") if err != nil { t.Errorf("setup: %s", err) return } defer foe.cleanup() // friend to friend: should work check(friend, friend, "", "") // friend to foe: friend client should reject foe server check(friend, foe, "x509: certificate signed by unknown authority", "remote error: bad certificate") // foe to friend: foe client accepts friend server but server should reject // foe client. // prevent client from rejecting the server's "bad" cert so we can test // that the server rejects the client's bad cert. foe.clientConfig.InsecureSkipVerify = true // this expects no certificate from the client because // tls/handshake_server.go sends a list of acceptable CA certs to the // client when requesting a client cert and tls/handshake_client.go avoids // sending any cert if no eligible one is in its Config.Certificates. // // a more complete test would be to force the client to send a cert it // knows is bad, but that's a hassle in a pure test so I verified it // manually with a Go 1.2.1 patched to not send the list of CAs from the // server and got: /* --- FAIL: TestCertCreator (0.20 seconds) cert_creator_test.go:224: foe->friend server: expected error "tls: client didn't provide a certificate" but got "tls: failed to verify client's certificate: x509: certificate signed by unknown authority" */ check(foe, friend, "remote error: bad certificate", "tls: client didn't provide a certificate") foe.clientConfig.InsecureSkipVerify = false done = true friend.listener.Close() foe.listener.Close() c := make(chan struct{}) go func() { wg.Wait() c <- struct{}{} }() select { case <-c: case <-time.After(time.Second): t.Errorf("timed out on wg.Wait") } }