// RoundTrip implements the RoundTripper interface. func (t *SSRoundTriper) RoundTrip(req *http.Request, ctx *goproxy.ProxyCtx) (resp *http.Response, err error) { host := req.URL.Host if byteIndex(host, ':') < 0 { host = host + ":80" } rawaddr, err := ss.RawAddr(host) if err != nil { panic("Error getting raw address.") } tr := &http.Transport{ Dial: func(_, _ string) (net.Conn, error) { return createServerConn(rawaddr, host) }, } client := &http.Client{ Transport: tr, } return client.Do(req) }
func main() { flag.StringVar(&config.server, "s", "127.0.0.1", "server:port") flag.IntVar(&config.port, "p", 0, "server:port") flag.IntVar(&config.core, "core", 1, "number of CPU cores to use") flag.StringVar(&config.passwd, "k", "", "password") flag.StringVar(&config.method, "m", "", "encryption method, use empty string or rc4") flag.IntVar(&config.nconn, "nc", 1, "number of connection to server") flag.IntVar(&config.nreq, "nr", 1, "number of request for each connection") // flag.IntVar(&config.nsec, "ns", 0, "run how many seconds for each connection") flag.BoolVar((*bool)(&debug), "d", false, "print http response body for debugging") flag.Parse() if config.server == "" || config.port == 0 || config.passwd == "" || len(flag.Args()) != 1 { fmt.Printf("Usage: %s -s <server> -p <port> -k <password> <url>\n", os.Args[0]) os.Exit(1) } if err := ss.SetDefaultCipher(config.method); err != nil { fmt.Println(err) os.Exit(1) } runtime.GOMAXPROCS(config.core) uri := flag.Arg(0) if !strings.HasPrefix(uri, "http://") { uri = "http://" + uri } cipher, err := ss.NewCipher(config.passwd) if err != nil { fmt.Println("Error creating cipher:", err) os.Exit(1) } serverAddr := net.JoinHostPort(config.server, strconv.Itoa(config.port)) parsedURL, err := url.Parse(uri) if err != nil { fmt.Println("Error parsing url:", err) os.Exit(1) } host, _, err := net.SplitHostPort(parsedURL.Host) if err != nil { host = net.JoinHostPort(parsedURL.Host, "80") } else { host = parsedURL.Host } // fmt.Println(host) rawAddr, err := ss.RawAddr(host) if err != nil { panic("Error getting raw address.") return } done := make(chan []time.Duration) for i := 1; i <= config.nconn; i++ { go get(i, uri, serverAddr, rawAddr, cipher, done) } // collect request finish time reqTime := make([]int64, config.nconn*config.nreq) reqDone := 0 for i := 1; i <= config.nconn; i++ { rt := <-done for _, t := range rt { reqTime[reqDone] = int64(t) reqDone++ } } fmt.Println("number of total requests:", config.nconn*config.nreq) fmt.Println("number of finished requests:", reqDone) if reqDone == 0 { return } // calculate average an standard deviation reqTime = reqTime[:reqDone] var sum int64 for _, d := range reqTime { sum += d } avg := float64(sum) / float64(reqDone) varSum := float64(0) for _, d := range reqTime { di := math.Abs(float64(d) - avg) di *= di varSum += di } stddev := math.Sqrt(varSum / float64(reqDone)) fmt.Println("\naverage time per request:", time.Duration(avg)) fmt.Println("standard deviation:", time.Duration(stddev)) }
func TestSsProxy(t *testing.T) { // 测试域名 testAddr := "www.test123.com:80" testData := []byte("dsgbhdfhgsq36jhrawdxghucn46ggetst") testServerAddr := "127.0.0.1:1458" testMethod := "aes-256-cfb" testPassword := "******" testRawAddr, err := ss.RawAddr(testAddr) if err != nil { t.Fatal(err) } // 简单模拟服务器 cipher, err := ss.NewCipher(testMethod, testPassword) if err != nil { t.Fatal(err) } l, err := net.Listen("tcp", testServerAddr) if err != nil { t.Fatal(err) } defer l.Close() go func() { c, err := l.Accept() if err != nil { t.Fatal(err) } c.SetDeadline(time.Now().Add(5 * time.Second)) sc := ss.NewConn(c, cipher) defer sc.Close() // 读内容并返回 for i := 0; i < 2; i++ { buf := make([]byte, 1024) if n, err := sc.Read(buf); err != nil { t.Fatal("i=", i, "服务器读内容错误:", err) } else { if _, err := sc.Write(buf[:n]); err != nil { t.Fatal(err) } } } }() // 发出请求,然后解密。 p, err := newSsProxyClient(testServerAddr, testMethod, testPassword, nil, nil) if err != nil { t.Fatal(err) } c, err := p.DialTimeout("tcp", testAddr, 1*time.Second) if err != nil { t.Fatal(err) } defer c.Close() // 比较地址是否正确 buf := make([]byte, 1024) if n, err := c.Read(buf); err != nil { t.Fatal("读地址错误:", err) } else { if bytes.Compare(buf[:n], testRawAddr) != 0 { t.Fatal("地址未正确发送") } } // 发送测试数据,并读取比较 if _, err := c.Write(testData); err != nil { t.Fatal(err) } if n, err := c.Read(buf); err != nil { t.Fatal(err) } else { if bytes.Compare(buf[:n], testData) != 0 { t.Fatal("数据未正确发送") } } }
func (p *ssProxyClient) DialTCPSAddrTimeout(network string, raddr string, timeout time.Duration) (rconn ProxyTCPConn, rerr error) { // 截止时间 finalDeadline := time.Time{} if timeout != 0 { finalDeadline = time.Now().Add(timeout) } ra, err := ss.RawAddr(raddr) if err != nil { return } c, err := p.upProxy.DialTCPSAddrTimeout(network, p.proxyAddr, timeout) if err != nil { return nil, fmt.Errorf("无法连接代理服务器 %v ,错误:%v", p.proxyAddr, err) } ch := make(chan int) defer close(ch) // 实际执行部分 run := func() { sc := ss.NewConn(c, p.cipher.Copy()) closed := false // 当连接不被使用时,ch<-1会引发异常,这时将关闭连接。 defer func() { e := recover() if e != nil && closed == false { sc.Close() } }() if _, err := sc.Write(ra); err != nil { closed = true sc.Close() rerr = err ch <- 0 return } r := ssTCPConn{TCPConn: c, sc: sc, proxyClient: p} //{c,net.ResolveTCPAddr("tcp","0.0.0.0:0"),net.ResolveTCPAddr("tcp","0.0.0.0:0"),"","",0,0 p} rconn = &r ch <- 1 } if timeout == 0 { go run() select { case <-ch: return } } else { c.SetDeadline(finalDeadline) ntimeout := finalDeadline.Sub(time.Now()) if ntimeout <= 0 { return nil, fmt.Errorf("timeout") } t := time.NewTimer(ntimeout) defer t.Stop() go run() select { case <-t.C: return nil, fmt.Errorf("连接超时。") case <-ch: if rerr == nil { c.SetDeadline(time.Time{}) } return } } }