func getRequest(conn *ss.Conn, auth bool) (host string, ota bool, err error) { ss.SetReadTimeout(conn) // buf size should at least have the same size with the largest possible // request size (when addrType is 3, domain name has at most 256 bytes) // 1(addrType) + 1(lenByte) + 256(max length address) + 2(port) + 10(hmac-sha1) buf := make([]byte, 270) // read till we get possible domain length field if _, err = io.ReadFull(conn, buf[:idType+1]); err != nil { return } var reqStart, reqEnd int addrType := buf[idType] switch addrType & ss.AddrMask { case typeIPv4: reqStart, reqEnd = idIP0, idIP0+lenIPv4 case typeIPv6: reqStart, reqEnd = idIP0, idIP0+lenIPv6 case typeDm: if _, err = io.ReadFull(conn, buf[idType+1:idDmLen+1]); err != nil { return } reqStart, reqEnd = idDm0, int(idDm0+buf[idDmLen]+lenDmBase) default: err = fmt.Errorf("addr type %d not supported", addrType&ss.AddrMask) return } if _, err = io.ReadFull(conn, buf[reqStart:reqEnd]); err != nil { return } // Return string for typeIP is not most efficient, but browsers (Chrome, // Safari, Firefox) all seems using typeDm exclusively. So this is not a // big problem. switch addrType & ss.AddrMask { case typeIPv4: host = net.IP(buf[idIP0 : idIP0+net.IPv4len]).String() case typeIPv6: host = net.IP(buf[idIP0 : idIP0+net.IPv6len]).String() case typeDm: host = string(buf[idDm0 : idDm0+buf[idDmLen]]) } // parse port port := binary.BigEndian.Uint16(buf[reqEnd-2 : reqEnd]) host = net.JoinHostPort(host, strconv.Itoa(int(port))) // if specified one time auth enabled, we should verify this if auth || addrType&ss.OneTimeAuthMask > 0 { ota = true if _, err = io.ReadFull(conn, buf[reqEnd:reqEnd+lenHmacSha1]); err != nil { return } iv := conn.GetIv() key := conn.GetKey() actualHmacSha1Buf := ss.HmacSha1(append(iv, key...), buf[:reqEnd]) if !bytes.Equal(buf[reqEnd:reqEnd+lenHmacSha1], actualHmacSha1Buf) { err = fmt.Errorf("verify one time auth failed, iv=%v key=%v data=%v", iv, key, buf[:reqEnd]) return } } return }
func handleConnection(conn *ss.Conn, auth bool) { var host string connCnt++ // this maybe not accurate, but should be enough if connCnt-nextLogConnCnt >= 0 { // XXX There's no xadd in the atomic package, so it's difficult to log // the message only once with low cost. Also note nextLogConnCnt maybe // added twice for current peak connection number level. log.Printf("Number of client connections reaches %d\n", nextLogConnCnt) nextLogConnCnt += logCntDelta } // function arguments are always evaluated, so surround debug statement // with if statement if debug { debug.Printf("new client %s->%s\n", conn.RemoteAddr().String(), conn.LocalAddr()) } closed := false defer func() { if debug { debug.Printf("closed pipe %s<->%s\n", conn.RemoteAddr(), host) } connCnt-- if !closed { conn.Close() } }() host, ota, err := getRequest(conn, auth) if err != nil { log.Println("error getting request", conn.RemoteAddr(), conn.LocalAddr(), err) return } debug.Println("connecting", host) remote, err := net.Dial("tcp", host) if err != nil { if ne, ok := err.(*net.OpError); ok && (ne.Err == syscall.EMFILE || ne.Err == syscall.ENFILE) { // log too many open file error // EMFILE is process reaches open file limits, ENFILE is system limit log.Println("dial error:", err) } else { log.Println("error connecting to:", host, err) } return } defer func() { if !closed { remote.Close() } }() if debug { debug.Printf("piping %s<->%s ota=%v connOta=%v", conn.RemoteAddr(), host, ota, conn.IsOta()) } if ota { go ss.PipeThenCloseOta(conn, remote) } else { go ss.PipeThenClose(conn, remote) } ss.PipeThenClose(remote, conn) closed = true return }