// websocketHTTPConnect opens a TCP connection to the HTTP proxy at proxy
// ("host:port") and asks it, with a CONNECT request, to tunnel to the host
// named in urlString. On success the returned net.Conn is the raw tunnel,
// ready for the WebSocket handshake with the target.
//
// CONNECT needs a host:port authority, so when urlString has no explicit port
// 443 is used for wss/https and 80 otherwise. Any non-2xx answer from the
// proxy is returned as an error and the proxy connection is closed.
//
// Taken and reworked from: https://gist.github.com/madmo/8548738
func websocketHTTPConnect(proxy, urlString string) (net.Conn, error) {
	turl, err := url.Parse(urlString)
	if err != nil {
		return nil, err
	}
	target := turl.Host
	if turl.Port() == "" {
		port := "80"
		if turl.Scheme == "wss" || turl.Scheme == "https" {
			port = "443"
		}
		target = net.JoinHostPort(turl.Hostname(), port)
	}

	p, err := net.Dial("tcp", proxy)
	if err != nil {
		return nil, err
	}

	req := http.Request{
		Method: "CONNECT",
		URL:    &url.URL{},
		Host:   target,
	}
	cc := httputil.NewProxyClientConn(p, nil)
	resp, err := cc.Do(&req)
	// ErrPersistEOF only says the proxy won't reuse the connection for another
	// HTTP exchange; the CONNECT response itself was read and is valid.
	if err != nil && err != httputil.ErrPersistEOF {
		p.Close()
		return nil, err
	}
	if resp.StatusCode/100 != 2 {
		p.Close()
		return nil, fmt.Errorf("proxy %s refused CONNECT to %s: %s", proxy, target, resp.Status)
	}

	// NOTE(review): Hijack also returns a bufio.Reader that may already hold
	// bytes sent through the tunnel right after the proxy's reply; they are
	// dropped here.
	rwc, _ := cc.Hijack()
	return rwc, nil
}
// dial dials the host specified by req, using TLS if appropriate, optionally // using a proxy server if one is configured via environment variables. func (s *SpdyRoundTripper) dial(req *http.Request) (net.Conn, error) { proxyURL, err := http.ProxyFromEnvironment(req) if err != nil { return nil, err } if proxyURL == nil { return s.dialWithoutProxy(req.URL) } // ensure we use a canonical host with proxyReq targetHost := netutil.CanonicalAddr(req.URL) // proxying logic adapted from http://blog.h6t.eu/post/74098062923/golang-websocket-with-http-proxy-support proxyReq := http.Request{ Method: "CONNECT", URL: &url.URL{}, Host: targetHost, } proxyDialConn, err := s.dialWithoutProxy(proxyURL) if err != nil { return nil, err } proxyClientConn := httputil.NewProxyClientConn(proxyDialConn, nil) _, err = proxyClientConn.Do(&proxyReq) if err != nil && err != httputil.ErrPersistEOF { return nil, err } rwc, _ := proxyClientConn.Hijack() if req.URL.Scheme != "https" { return rwc, nil } host, _, err := net.SplitHostPort(req.URL.Host) if err != nil { return nil, err } if len(s.tlsConfig.ServerName) == 0 { s.tlsConfig.ServerName = host } tlsConn := tls.Client(rwc, s.tlsConfig) // need to manually call Handshake() so we can call VerifyHostname() below if err := tlsConn.Handshake(); err != nil { return nil, err } if err := tlsConn.VerifyHostname(host); err != nil { return nil, err } return tlsConn, nil }
// function to do the actual testing and output the necessary bits // we don't really care about ordering here otherwise we'd need to do a better job of organizing the output func (w *Worker) testHost(work *Work) { // more debug w.dbg(3, "Ok.. prepping to test %s\n", work.target) // pre-define myConn/tlsConn/client var myConn net.Conn var tlsConn *tls.Conn var client *httputil.ClientConn // need a couple of vars for holding ip addresses and default them to being empty var ip4 string = "" var ip6 string = "" // and a var to stash our full host:port combo in var address string = "" // create a buffer we'll use later for reading client body buffer := make([]byte, 1024) // build our request request, err := http.NewRequest("GET", work.url, nil) if err != nil { w.dbgError("Failed to build request for %s : %s\n", work.url, err) return } if config.ua == "" { request.Header.Set("User-Agent", version) } else { request.Header.Set("User-Agent", config.ua) } // debug w.dbg(3, "Request: %+v\n", request) // set a timer for the connection timestart := time.Now() ipaddr, ierr := net.LookupHost(work.target) if ierr != nil { w.dbgError("Failed to lookup %s : %s", work.target, ierr) return } for _, x := range ipaddr { if len(x) > 16 { if ip6 == "" { ip6 = x } } else if ip4 == "" { ip4 = x } } timestop := time.Now() iplookup := (timestop.Sub(timestart)) // build our address string if config.tcpfour { address = fmt.Sprintf("%s:%d", ip4, work.port) } else { address = fmt.Sprintf("[%s]:%d", ip6, work.port) } // debug w.dbg(1, "Attempting to connect to %s\n", address) // since we want some very low level access to bits and pieces, we're going to have to use tcp dial vs the native http client // create a net.conn w.dbg(1, "Connecting to %s\n", address) if config.tcpfour { myConn, err = net.DialTimeout("tcp4", address, time.Duration(config.timeout)*time.Second) } else if config.tcpsix { myConn, err = net.DialTimeout("tcp6", address, time.Duration(config.timeout)*time.Second) } else { myConn, 
err = net.DialTimeout("tcp", address, time.Duration(config.timeout)*time.Second) } if err != nil { w.dbgError("Could not connect to %s : %s\n", address, err) return } w.dbg(2, "Connected to %s\n", address) // get a time reading on how long it took to connect to the socket timestop = time.Now() tcpConnect := (timestop.Sub(timestart)) // defer close defer myConn.Close() // need to add some deadlines so we don't sit around indefintely - 5s is more than sufficient myConn.SetDeadline(time.Now().Add(time.Duration(5 * time.Second))) // if we're an ssl connection, we need a few extra steps here if work.ssl { w.dbg(1, "Starting SSL procedures...\n") // default to allowing insecure ssl tlsConfig := tls.Config{InsecureSkipVerify: true} // create a real tls connection tlsConn = tls.Client(myConn, &tlsConfig) // do our SSL negotiation err = tlsConn.Handshake() if err != nil { w.dbgError("Could not negotiate tls handshake on %s : %s\n", address, err) return } // defer closing this connection as well defer tlsConn.Close() } // get a time reading on how long it took to negotiate ssl timestop = time.Now() sslHandshake := (timestop.Sub(timestart)) // get our state if work.ssl { state := tlsConn.ConnectionState() w.dbg(2, "Handshake Complete: %t\n", state.HandshakeComplete) w.dbg(2, "Mutual: %t\n", state.NegotiatedProtocolIsMutual) } // debug w.dbg(3, "Converting to an HTTP client connection...\n") // convert to an http connection if work.ssl { client = httputil.NewProxyClientConn(tlsConn, nil) } else { client = httputil.NewProxyClientConn(myConn, nil) } // debug w.dbg(1, "Making GET request\n") // write our request to the socket err = client.Write(request) if err != nil { w.dbgError("Error writing request : %s\n", err) return } // read our response headers response, err := client.Read(request) if err != nil { // did we get a 400? if response.StatusCode == 400 { w.dbgError("400 response received.. \n") return } // did we get a 404? 
if response.StatusCode == 404 { w.dbgError("404 response received.. \n") return } // any other error, exit out w.dbgError("Error reading response : %s\n", err) return } w.dbg(1, "Status: %s\n", response.Status) // did we get a response? if len(response.Header) == 0 { w.dbgError("0 length response, something probably broke") return } // measure response header time timestop = time.Now() respTime := (timestop.Sub(timestart)) // defer close since we still want to read the body of the object defer response.Body.Close() // build a reader br := bufio.NewReader(response.Body) // now read the first byte c, err := br.ReadByte() if err != nil { w.dbgError("Could not read data: %s\n", err) return } // measure our first byte time, this is normally 0ms however longer periods could be indicative of a problem timestop = time.Now() byteTime := (timestop.Sub(timestart)) // ok, read the rest of the response n, err := br.Read(buffer) count := n for err != io.EOF { n, err = br.Read(buffer) count += n } // did we fail to read everything? 
if err != nil && err != io.EOF { w.dbgError("Error on data read, continuing with only %d bytes of %s read\n", count+1, response.Header.Get("Content-Length")) } // measure our overall time to proccess the entire transaction timestop = time.Now() totalTime := (timestop.Sub(timestart)) w.dbg(2, "Received %d bytes total\n", count+1) w.dbg(3, "Response: %s%s\n", string(c), string(buffer)) // shut down the client client.Close() // properly close out our other connections myConn.Close() if work.ssl { tlsConn.Close() } if work.gph != nil { // Graphite selected w.dbg(2, "Graphite selected...\n") gname := strings.Replace(work.target, ".", "_", -1) prefix := "" if config.prefix != "" { prefix = fmt.Sprintf("%s/", config.prefix) } if work.ssl { work.gph.PostOne(fmt.Sprintf("%s%s/ssl_port_%d/dns_time", prefix, gname, work.port), float64(iplookup/1000000)) work.gph.PostOne(fmt.Sprintf("%s%s/ssl_port_%d/connect_time", prefix, gname, work.port), float64(tcpConnect/1000000)) work.gph.PostOne(fmt.Sprintf("%s%s/ssl_port_%d/ssl_time", prefix, gname, work.port), float64(sslHandshake/1000000)) work.gph.PostOne(fmt.Sprintf("%s%s/ssl_port_%d/response_time", prefix, gname, work.port), float64(respTime/1000000)) work.gph.PostOne(fmt.Sprintf("%s%s/ssl_port_%d/byte_time", prefix, gname, work.port), float64(byteTime/1000000)) work.gph.PostOne(fmt.Sprintf("%s%s/ssl_port_%d/total_time", prefix, gname, work.port), float64(totalTime/1000000)) } else { work.gph.PostOne(fmt.Sprintf("%s%s/http_port_%d/dns_time", prefix, gname, work.port), float64(iplookup/1000000)) work.gph.PostOne(fmt.Sprintf("%s%s/http_port_%d/connect_time", prefix, gname, work.port), float64(tcpConnect/1000000)) work.gph.PostOne(fmt.Sprintf("%s%s/http_port_%d/response_time", prefix, gname, work.port), float64(respTime/1000000)) work.gph.PostOne(fmt.Sprintf("%s%s/http_port_%d/byte_time", prefix, gname, work.port), float64(byteTime/1000000)) work.gph.PostOne(fmt.Sprintf("%s%s/http_port_%d/total_time", prefix, gname, work.port), 
float64(totalTime/1000000)) } // if we've defined multiple, output in all defined formats + stdout if !config.multiple { return } } // PublishMetric(host string, instance string, key string, value int64) (error Error) { // add tsdb writing if config.tval != "" { w.dbg(2, "Writing to tsdb server...\n") var METRIC []interface{} var tags map[string]interface{} var metric map[string]interface{} tstamp := int64(time.Now().Unix()) if work.ssl { tags = map[string]interface{}{"host": work.target, "port": work.port, "type": "ssl"} metric = map[string]interface{}{"metric": "handshake_time", "value": int64(sslHandshake / 1000000), "timestamp": tstamp, "tags": tags} METRIC = append(METRIC, metric) } else { tags = map[string]interface{}{"Host": work.target, "Port": work.port, "Type": "http"} } metric = map[string]interface{}{"metric": "dns_lookup", "value": int64(iplookup / 1000000), "timestamp": tstamp, "tags": tags} METRIC = append(METRIC, metric) metric = map[string]interface{}{"metric": "connect_time", "value": int64(tcpConnect / 1000000), "timestamp": tstamp, "tags": tags} METRIC = append(METRIC, metric) metric = map[string]interface{}{"metric": "response_time", "value": int64(respTime / 1000000), "timestamp": tstamp, "tags": tags} METRIC = append(METRIC, metric) metric = map[string]interface{}{"metric": "firstbyte_time", "value": int64(byteTime / 1000000), "timestamp": tstamp, "tags": tags} METRIC = append(METRIC, metric) metric = map[string]interface{}{"metric": "total_time", "value": int64(totalTime / 1000000), "timestamp": tstamp, "tags": tags} METRIC = append(METRIC, metric) b, _ := json.Marshal(METRIC) err = writeTSDB(b) if err != nil { dbg(1, "Error writing to tsdb server: %s\n", err) } // if we've defined multiple, output in all defined formats + stdout if !config.multiple { return } } // add collectd style output if config.collectd { if work.ssl { fmt.Fprintf(os.Stdout, "PUTVAL %s/ssl_port_%d/milliseconds-dns_time interval=%d N:%d\n", work.target, work.port, 
config.interval, iplookup/1000000) fmt.Fprintf(os.Stdout, "PUTVAL %s/ssl_port_%d/milliseconds-connect_time interval=%d N:%d\n", work.target, work.port, config.interval, tcpConnect/1000000) fmt.Fprintf(os.Stdout, "PUTVAL %s/ssl_port_%d/milliseconds-ssl_time interval=%d N:%d\n", work.target, work.port, config.interval, sslHandshake/1000000) fmt.Fprintf(os.Stdout, "PUTVAL %s/ssl_port_%d/milliseconds-response_time interval=%d N:%d\n", work.target, work.port, config.interval, respTime/1000000) fmt.Fprintf(os.Stdout, "PUTVAL %s/ssl_port_%d/milliseconds-byte_time interval=%d N:%d\n", work.target, work.port, config.interval, byteTime/1000000) fmt.Fprintf(os.Stdout, "PUTVAL %s/ssl_port_%d/milliseconds-total_time interval=%d N:%d\n", work.target, work.port, config.interval, totalTime/1000000) } else { fmt.Fprintf(os.Stdout, "PUTVAL %s/http_port_%d/milliseconds-dns_time interval=%d N:%d\n", work.target, work.port, config.interval, iplookup/1000000) fmt.Fprintf(os.Stdout, "PUTVAL %s/http_port_%d/milliseconds-connect_time interval=%d N:%d\n", work.target, work.port, config.interval, tcpConnect/1000000) fmt.Fprintf(os.Stdout, "PUTVAL %s/http_port_%d/milliseconds-response_time interval=%d N:%d\n", work.target, work.port, config.interval, respTime/1000000) fmt.Fprintf(os.Stdout, "PUTVAL %s/http_port_%d/milliseconds-byte_time interval=%d N:%d\n", work.target, work.port, config.interval, byteTime/1000000) fmt.Fprintf(os.Stdout, "PUTVAL %s/http_port_%d/milliseconds-total_time interval=%d N:%d\n", work.target, work.port, config.interval, totalTime/1000000) } // if we've defined multiple, output in all defined formats + stdout if !config.multiple { return } } // add function for zabbix if config.zabbix { if work.ssl { fmt.Fprintf(os.Stdout, "https_host=%s, port=%d, dns_lookup=%dms, socket_connect=%dms, ssl_negotiation=%dms, response_time=%dms, first_byte=%dms, total_time=%dms ", work.target, work.port, iplookup/1000000, tcpConnect/1000000, sslHandshake/1000000, respTime/1000000, 
byteTime/1000000, totalTime/1000000) if config.expires { i := 0 state := tlsConn.ConnectionState() for _, v := range state.PeerCertificates { if i == 0 { myT := time.Now() myExpires := v.NotAfter.Sub(myT) fmt.Fprintf(os.Stdout, "expires=%dd\n", myExpires/8.64e13) i++ } } } else { fmt.Fprintf(os.Stdout, "\n") } } else { fmt.Fprintf(os.Stdout, "host=%s, port=%d, dns_lookup=%dms, socket_connect=%dms, response_time=%dms, first_byte=%dms, total_time=%dms\n", work.target, work.port, iplookup/1000000, tcpConnect/1000000, respTime/1000000, byteTime/1000000, totalTime/1000000) } // if we've defined multiple, output in all defined formats + stdout if !config.multiple { return } } // print out our values if (!config.multiple) || (config.verbose) { if work.ssl { log.Printf("Host: %s -> DNS Lookup: %dms, Socket Connect: %dms, SSL Negotiation: %dms, Response Time: %dms, 1st Byte: %dms, Total Time: %dms\n", work.target, iplookup/1000000, tcpConnect/1000000, sslHandshake/1000000, respTime/1000000, byteTime/1000000, totalTime/1000000) } else { log.Printf("Host: %s -> DNS Lookup: %dms, Socket Connect: %dms, Response Time: %dms, 1st Byte: %dms, Total Time: %dms\n", work.target, iplookup/1000000, tcpConnect/1000000, respTime/1000000, byteTime/1000000, totalTime/1000000) } if work.ssl { if config.expires == true { i := 0 state := tlsConn.ConnectionState() for _, v := range state.PeerCertificates { if i == 0 { myT := time.Now() myExpires := v.NotAfter.Sub(myT) fmt.Fprintf(os.Stdout, "Cert expires in: %d days\n", myExpires/8.64e13) i++ } } } } } if config.verbose { log.Printf("StatusCode: %d\n", response.StatusCode) log.Printf("ProtoCol: %s\n", response.Proto) for k, v := range response.Header { log.Printf("%s: %v\n", k, v) } if work.ssl { if config.printcert == true { i := 0 state := tlsConn.ConnectionState() for _, v := range state.PeerCertificates { if i == 0 { sslFrom := v.NotBefore sslTo := v.NotAfter log.Printf("Server key information:") 
log.Printf("\tCN:\t%v\n\tOU:\t%v\n\tOrg:\t%v\n", v.Subject.CommonName, v.Subject.OrganizationalUnit, v.Subject.Organization) log.Printf("\tCity:\t%v\n\tState:\t%v\n\tCountry:%v\n", v.Subject.Locality, v.Subject.Province, v.Subject.Country) log.Printf("SSL Certificate Valid:\n\tFrom: %v\n\tTo: %v\n", sslFrom, sslTo) log.Printf("Valid Certificate DNS:\n") if len(v.DNSNames) >= 1 { for dns := range v.DNSNames { log.Printf("\t%v\n", v.DNSNames[dns]) } } else { log.Printf("\t%v\n", v.Subject.CommonName) } i++ } else if i == 1 { log.Printf("Issued by:\n\t%v\n\t%v\n\t%v\n", v.Subject.CommonName, v.Subject.OrganizationalUnit, v.Subject.Organization) i++ } else { // we're done here, lets move on break } } } } // throw in a new line to pretty it up log.Printf("") } return }
func (p *Proxy) ServeHTTP(cwr http.ResponseWriter, creq *http.Request) { // c = things towards the client of the proxy // o = things towards origin server if creq.Method == "CONNECT" { rc, err := net.Dial("tcp", creq.URL.Host) if err != nil { http.Error(cwr, err.Error(), http.StatusGatewayTimeout) loghit(creq, http.StatusGatewayTimeout) return } remote := bufio.NewReadWriter(bufio.NewReader(rc), bufio.NewWriter(rc)) cwr.WriteHeader(http.StatusOK) loghit(creq, http.StatusOK) hj, ok := cwr.(http.Hijacker) if !ok { panic("not hijackable") } wc, client, err := hj.Hijack() done := make(chan int) f := func(from, to *bufio.ReadWriter) { var err error n := 0 for err == nil { var c byte c, err = from.ReadByte() n++ if err == nil { err = to.WriteByte(c) to.Flush() } } done <- n } go f(remote, client) go f(client, remote) // wait for one side to finish and close both sides tot := <-done wc.Close() rc.Close() tot += <-done log.Print("CONNECT finished, ", tot, " bytes") return } oreq := new(http.Request) oreq.ProtoMajor = 1 oreq.ProtoMinor = 1 oreq.Close = true oreq.Header = creq.Header oreq.Method = creq.Method ourl, err := url.Parse(creq.RequestURI) if err != nil { http.Error(cwr, fmt.Sprint("Malformed request", err), http.StatusNotImplemented) loghit(creq, http.StatusNotImplemented) return } oreq.URL = ourl if oreq.URL.Scheme != "http" { http.Error(cwr, "I only proxy http", http.StatusNotImplemented) loghit(creq, http.StatusNotImplemented) return } if oreq.Method != "GET" && oreq.Method != "POST" { log.Print("Cannot handle method ", creq.Method) http.Error(cwr, "I only handle GET and POST", http.StatusNotImplemented) return } if oreq.Method == "POST" { oreq.Method = "POST" if _, ok := oreq.Header["Content-Type"]; !ok { oreq.Header.Set("Content-Type", "multipart/form-data") } oreq.ContentLength = creq.ContentLength oreq.Body = creq.Body } addr := oreq.URL.Host if !hasPort(addr) { addr += ":" + oreq.URL.Scheme } c, err := net.Dial("tcp", addr) if err != nil { http.Error(cwr, 
err.Error(), http.StatusGatewayTimeout) loghit(creq, http.StatusGatewayTimeout) return } c.SetReadDeadline(time.Now().Add(3 * time.Second)) cc := httputil.NewProxyClientConn(c, nil) // debug //dbg, err := http.DumpRequest(oreq, true) //log.Print("Dump request to origin server:\n", string(dbg)) err = cc.Write(oreq) if err != nil { http.Error(cwr, err.Error(), http.StatusGatewayTimeout) loghit(creq, http.StatusGatewayTimeout) return } oresp, err := cc.Read(oreq) if err != nil && err != httputil.ErrPersistEOF { http.Error(cwr, err.Error(), http.StatusGatewayTimeout) loghit(creq, http.StatusGatewayTimeout) return } //dbg, err = http.DumpResponse(oresp, true) //log.Print("Dump response from origin server:\n", string(dbg)) for hdr, val := range oresp.Header { if !doNotCopy[hdr] { h := cwr.Header() h[hdr] = val } } cwr.WriteHeader(oresp.StatusCode) // simulate it coming in over gLink, a shared rate-limited link io.Copy(cwr, gLink.NewLinkReader(oresp.Body)) cc.Close() c.Close() loghit(creq, oresp.StatusCode) }