func DialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) { dialAddr := netutil.CanonicalAddr(url) dialer, _ := utilnet.Dialer(transport) switch url.Scheme { case "http": if dialer != nil { return dialer("tcp", dialAddr) } return net.Dial("tcp", dialAddr) case "https": // Get the tls config from the transport if we recognize it var tlsConfig *tls.Config var tlsConn *tls.Conn var err error tlsConfig, _ = utilnet.TLSClientConfig(transport) if dialer != nil { // We have a dialer; use it to open the connection, then // create a tls client using the connection. netConn, err := dialer("tcp", dialAddr) if err != nil { return nil, err } if tlsConfig == nil { // tls.Client requires non-nil config glog.Warningf("using custom dialer with no TLSClientConfig. Defaulting to InsecureSkipVerify") // tls.Handshake() requires ServerName or InsecureSkipVerify tlsConfig = &tls.Config{ InsecureSkipVerify: true, } } else if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify { // tls.Handshake() requires ServerName or InsecureSkipVerify // infer the ServerName from the hostname we're connecting to. inferredHost := dialAddr if host, _, err := net.SplitHostPort(dialAddr); err == nil { inferredHost = host } // Make a copy to avoid polluting the provided config tlsConfigCopy := utilnet.CloneTLSConfig(tlsConfig) tlsConfigCopy.ServerName = inferredHost tlsConfig = tlsConfigCopy } tlsConn = tls.Client(netConn, tlsConfig) if err := tlsConn.Handshake(); err != nil { netConn.Close() return nil, err } } else { // Dial tlsConn, err = tls.Dial("tcp", dialAddr, tlsConfig) if err != nil { return nil, err } } // Return if we were configured to skip validation if tlsConfig != nil && tlsConfig.InsecureSkipVerify { return tlsConn, nil } // Verify host, _, _ := net.SplitHostPort(dialAddr) if err := tlsConn.VerifyHostname(host); err != nil { tlsConn.Close() return nil, err } return tlsConn, nil default: return nil, fmt.Errorf("Unknown scheme: %s", url.Scheme) } }
func TestDialURL(t *testing.T) { roots := x509.NewCertPool() if !roots.AppendCertsFromPEM(localhostCert) { t.Fatal("error setting up localhostCert pool") } cert, err := tls.X509KeyPair(localhostCert, localhostKey) if err != nil { t.Fatal(err) } testcases := map[string]struct { TLSConfig *tls.Config Dial func(network, addr string) (net.Conn, error) ExpectError string }{ "insecure": { TLSConfig: &tls.Config{InsecureSkipVerify: true}, }, "secure, no roots": { TLSConfig: &tls.Config{InsecureSkipVerify: false}, ExpectError: "unknown authority", }, "secure with roots": { TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots}, }, "secure with mismatched server": { TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "bogus.com"}, ExpectError: "not bogus.com", }, "secure with matched server": { TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "example.com"}, }, "insecure, custom dial": { TLSConfig: &tls.Config{InsecureSkipVerify: true}, Dial: net.Dial, }, "secure, no roots, custom dial": { TLSConfig: &tls.Config{InsecureSkipVerify: false}, Dial: net.Dial, ExpectError: "unknown authority", }, "secure with roots, custom dial": { TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots}, Dial: net.Dial, }, "secure with mismatched server, custom dial": { TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "bogus.com"}, Dial: net.Dial, ExpectError: "not bogus.com", }, "secure with matched server, custom dial": { TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "example.com"}, Dial: net.Dial, }, } for k, tc := range testcases { func() { ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {})) defer ts.Close() ts.TLS = &tls.Config{Certificates: []tls.Certificate{cert}} ts.StartTLS() tlsConfigCopy := utilnet.CloneTLSConfig(tc.TLSConfig) transport := &http.Transport{ Dial: tc.Dial, TLSClientConfig: tlsConfigCopy, } extractedDial, err := utilnet.Dialer(transport) if err != nil { t.Fatal(err) } if fmt.Sprintf("%p", extractedDial) != fmt.Sprintf("%p", tc.Dial) { t.Fatalf("%s: Unexpected dial", k) } extractedTLSConfig, err := utilnet.TLSClientConfig(transport) if err != nil { t.Fatal(err) } if extractedTLSConfig == nil { t.Fatalf("%s: Expected tlsConfig", k) } u, _ := url.Parse(ts.URL) _, p, _ := net.SplitHostPort(u.Host) u.Host = net.JoinHostPort("127.0.0.1", p) conn, err := DialURL(u, transport) // Make sure dialing doesn't mutate the transport's TLSConfig if !reflect.DeepEqual(tc.TLSConfig, tlsConfigCopy) { t.Errorf("%s: transport's copy of TLSConfig was mutated\n%#v\n\n%#v", k, tc.TLSConfig, tlsConfigCopy) } if err != nil { if tc.ExpectError == "" { t.Errorf("%s: expected no error, got %q", k, err.Error()) } if !strings.Contains(err.Error(), tc.ExpectError) { t.Errorf("%s: expected error containing %q, got %q", k, tc.ExpectError, err.Error()) } return } conn.Close() if tc.ExpectError != "" { t.Errorf("%s: expected error %q, got none", k, tc.ExpectError) } }() } }