Exemple #1
0
// copy streams bytes from src to dest in chunks of up to 40 KiB, flushing
// dest after every chunk so data is forwarded as soon as it is read.
//
// It returns when src yields no more data (io.EOF or a zero-length read),
// when reading fails, or when writing to or flushing dest fails. Errors are
// not reported to the caller.
//
// Note that this function shadows the builtin copy within its package.
func copy(dest *bufio.ReadWriter, src *bufio.ReadWriter) {
	buf := make([]byte, 40*1024)
	for {
		n, err := src.Read(buf)
		if err != nil && err != io.EOF {
			return
		}
		if n == 0 {
			return
		}
		// bufio.Writer errors are sticky: once a write fails every later
		// write fails too, so there is no point in continuing to read src.
		if _, err := dest.Write(buf[:n]); err != nil {
			return
		}
		if err := dest.Flush(); err != nil {
			return
		}
	}
}
Exemple #2
0
// Copy streams bytes from src to dest in chunks of up to 40 KiB, flushing
// dest after every chunk so data is forwarded as soon as it is read.
//
// It returns when src yields no more data (io.EOF or a zero-length read),
// when reading fails, or when writing to or flushing dest fails. Failures
// are logged; nothing is reported to the caller.
func Copy(dest *bufio.ReadWriter, src *bufio.ReadWriter) {
	buf := make([]byte, 40*1024)
	for {
		n, err := src.Read(buf)
		if err != nil && err != io.EOF {
			log.Printf("Read failed: %v", err)
			return
		}
		if n == 0 {
			return
		}
		// bufio.Writer errors are sticky: once a write fails every later
		// write fails too, so stop instead of draining src into nothing.
		if _, err := dest.Write(buf[:n]); err != nil {
			log.Printf("Write failed: %v", err)
			return
		}
		if err := dest.Flush(); err != nil {
			log.Printf("Flush failed: %v", err)
			return
		}
	}
}
// aliveConn reports whether the peer of conn still appears to be connected.
//
// It arms a 10ms read deadline on conn and peeks a single byte through
// bufrw (assumed to wrap conn; verify against the caller). Only io.EOF is
// treated as "closed"; any other error, including the deadline expiring
// because the peer is simply idle, reports the connection as alive.
//
// The byte is peeked rather than read so that data sent by the peer stays
// in bufrw for the next real read, and the read deadline is cleared before
// returning so the probe does not make later reads on conn time out.
//
// Dirty but workable liveness test, following Russ Cox's suggestion:
// https://groups.google.com/d/msg/golang-nuts/oaKW4WMTdK8/3qiR2Mvn43kJ
func aliveConn(conn net.Conn, bufrw *bufio.ReadWriter) (open bool) {
	if err := conn.SetReadDeadline(time.Now().Add(10 * time.Millisecond)); err != nil {
		l.Println("unable to set deadline", err)
	}
	// A zero time.Time removes the deadline again.
	defer func() {
		if err := conn.SetReadDeadline(time.Time{}); err != nil {
			l.Println("unable to reset deadline", err)
		}
	}()

	peeked, err := bufrw.Peek(1)
	if err != nil {
		// EOF means the peer closed the connection; a timeout or any other
		// error is not proof of that.
		return err != io.EOF
	}

	l.Println("read :", len(peeked), peeked)
	return true
}
Exemple #4
0
// Copy streams bytes one way, from src to dest, in chunks of up to 40 KiB,
// flushing dest after every chunk.
//
// route is dereferenced on every chunk and Seen is called on the resulting
// *Route before the chunk is forwarded, so route must be non-nil and must
// point at a non-nil *Route for as long as Copy runs.
//
// Copy returns when src yields no more data (io.EOF or a zero-length read),
// when reading fails, or when writing to or flushing dest fails. Write and
// flush failures are logged as warnings; nothing is reported to the caller.
func Copy(dest *bufio.ReadWriter, src *bufio.ReadWriter, route **Route) {
	buf := make([]byte, 40*1024)
	for {
		n, err := src.Read(buf)
		if err != nil && err != io.EOF {
			//log.Error("Read failed: %v", err)
			return
		}
		if n == 0 {
			return
		}
		(*route).Seen()

		// bufio.Writer errors are sticky: after one failure every later
		// write fails as well, so stop here instead of logging a warning
		// for each remaining chunk while silently discarding the data.
		if _, err := dest.Write(buf[:n]); err != nil {
			log.Warning("Could not write to dest", err)
			return
		}
		if err := dest.Flush(); err != nil {
			log.Warning("Could not flush to dest", err)
			return
		}
	}
}
// bufferedCopy relays data from src to dest using a 40 KiB scratch buffer,
// flushing dest after each chunk. It stops silently once src has nothing
// more to deliver, and stops with a log message when the upstream read, the
// downstream write, or the downstream flush fails.
func bufferedCopy(dest *bufio.ReadWriter, src *bufio.ReadWriter) {
	chunk := make([]byte, 40*1024)
	for {
		got, readErr := src.Read(chunk)
		switch {
		case readErr != nil && readErr != io.EOF:
			log.Printf("Upstream read failed: %v", readErr)
			return
		case got == 0:
			return
		}

		// An io.EOF from the writer is deliberately not treated as a write
		// failure here; the flush below still reports it.
		if _, writeErr := dest.Write(chunk[:got]); writeErr != nil && writeErr != io.EOF {
			log.Printf("Downstream write failed: %v", writeErr)
			return
		}

		if flushErr := dest.Flush(); flushErr != nil {
			log.Printf("Downstream write flush failed: %v", flushErr)
			return
		}
	}
}
Exemple #6
0
// handle reads one HTTP request from brw and serves it, writing the
// response back to brw.
//
// Depending on the request it takes one of these paths:
//
//   - Requests matched by p.mux are served locally by the matched handler.
//   - CONNECT requests are answered with a 200 and then, when p.mitm is set,
//     re-handled on the same connection (wrapped in a TLS server if the
//     first tunneled byte is a TLS handshake record); otherwise a tunnel to
//     the target is established via p.connect and bytes are copied in both
//     directions until both directions finish.
//   - Any other request is run through the request modifier, round-tripped,
//     run through the response modifier and written to brw.
//
// Modifier errors are logged and attached to the message as warnings; they
// do not abort the request.
//
// The return value is errClose when the request could not be read, when a
// close was requested (req.Close, res.Close or p.Closing()), or after a
// CONNECT tunnel ends; errClose presumably tells the caller to close the
// connection (confirm against the caller). Errors from building the
// session context, from the MITM TLS handshake, and from flushing the 502
// reply of a failed CONNECT are returned as-is. Otherwise nil is returned.
func (p *Proxy) handle(ctx *Context, conn net.Conn, brw *bufio.ReadWriter) error {
	log.Debugf("martian: waiting for request: %v", conn.RemoteAddr())

	req, err := http.ReadRequest(brw.Reader)
	if err != nil {
		if isCloseable(err) {
			log.Debugf("martian: connection closed prematurely: %v", err)
		} else {
			log.Errorf("martian: failed to read request: %v", err)
		}

		// TODO: TCPConn.WriteClose() to avoid sending an RST to the client.

		return errClose
	}
	defer req.Body.Close()

	// Requests matching a registered pattern are configuration requests for
	// the proxy itself and are answered locally instead of being forwarded.
	if h, pattern := p.mux.Handler(req); pattern != "" {
		defer brw.Flush()

		closing := req.Close || p.Closing()

		log.Infof("martian: intercepted configuration request: %s", req.URL)
		rw := newResponseWriter(brw, closing)
		defer rw.Close()

		h.ServeHTTP(rw, req)

		// Call WriteHeader to ensure a response is sent, since the handler isn't
		// required to call WriteHeader/Write.
		rw.WriteHeader(200)

		if closing {
			return errClose
		}

		return nil
	}

	// Replace ctx with a context built from the current session.
	ctx, err = withSession(ctx.Session())
	if err != nil {
		log.Errorf("martian: failed to build new context: %v", err)
		return err
	}

	link(req, ctx)
	defer unlink(req)

	if tconn, ok := conn.(*tls.Conn); ok {
		ctx.Session().MarkSecure()

		cs := tconn.ConnectionState()
		req.TLS = &cs
	}

	req.URL.Scheme = "http"
	if ctx.Session().IsSecure() {
		log.Debugf("martian: forcing HTTPS inside secure session")
		req.URL.Scheme = "https"
	}

	req.RemoteAddr = conn.RemoteAddr().String()
	if req.URL.Host == "" {
		req.URL.Host = req.Host
	}

	log.Infof("martian: received request: %s", req.URL)

	if req.Method == "CONNECT" {
		if err := p.reqmod.ModifyRequest(req); err != nil {
			log.Errorf("martian: error modifying CONNECT request: %v", err)
			proxyutil.Warning(req.Header, err)
		}

		if p.mitm != nil {
			log.Debugf("martian: attempting MITM for connection: %s", req.Host)
			res := proxyutil.NewResponse(200, nil, req)

			if err := p.resmod.ModifyResponse(res); err != nil {
				log.Errorf("martian: error modifying CONNECT response: %v", err)
				proxyutil.Warning(res.Header, err)
			}

			if err := res.Write(brw); err != nil {
				log.Errorf("martian: got error while writing response back to client: %v", err)
			}
			if err := brw.Flush(); err != nil {
				log.Errorf("martian: got error while flushing response back to client: %v", err)
			}

			log.Debugf("martian: completed MITM for connection: %s", req.Host)

			// Read the first tunneled byte to tell TLS from plaintext. On a
			// read error b stays zero and the plaintext path below is taken.
			b := make([]byte, 1)
			if _, err := brw.Read(b); err != nil {
				log.Errorf("martian: error peeking message through CONNECT tunnel to determine type: %v", err)
			}

			// Drain all of the rest of the buffered data.
			buf := make([]byte, brw.Reader.Buffered())
			brw.Read(buf)

			// 22 is the TLS handshake.
			// https://tools.ietf.org/html/rfc5246#section-6.2.1
			if b[0] == 22 {
				// Prepend the previously read data to be read again by
				// http.ReadRequest.
				tlsconn := tls.Server(&peekedConn{conn, io.MultiReader(bytes.NewReader(b), bytes.NewReader(buf), conn)}, p.mitm.TLSForHost(req.Host))

				if err := tlsconn.Handshake(); err != nil {
					p.mitm.HandshakeErrorCallback(req, err)
					return err
				}

				// From here on brw reads from and writes to the TLS session.
				brw.Writer.Reset(tlsconn)
				brw.Reader.Reset(tlsconn)

				return p.handle(ctx, tlsconn, brw)
			}

			// Prepend the previously read data to be read again by http.ReadRequest.
			brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), bytes.NewReader(buf), conn))
			return p.handle(ctx, conn, brw)
		}

		log.Debugf("martian: attempting to establish CONNECT tunnel: %s", req.URL.Host)
		res, cconn, cerr := p.connect(req)
		if cerr != nil {
			// Log cerr: the outer err is nil at this point.
			log.Errorf("martian: failed to CONNECT: %v", cerr)
			res = proxyutil.NewResponse(502, nil, req)
			proxyutil.Warning(res.Header, cerr)

			if err := p.resmod.ModifyResponse(res); err != nil {
				log.Errorf("martian: error modifying CONNECT response: %v", err)
				proxyutil.Warning(res.Header, err)
			}

			if err := res.Write(brw); err != nil {
				log.Errorf("martian: got error while writing response back to client: %v", err)
			}
			err := brw.Flush()
			if err != nil {
				log.Errorf("martian: got error while flushing response back to client: %v", err)
			}
			return err
		}
		defer res.Body.Close()
		defer cconn.Close()

		if err := p.resmod.ModifyResponse(res); err != nil {
			log.Errorf("martian: error modifying CONNECT response: %v", err)
			proxyutil.Warning(res.Header, err)
		}

		if err := res.Write(brw); err != nil {
			log.Errorf("martian: got error while writing response back to client: %v", err)
		}
		if err := brw.Flush(); err != nil {
			log.Errorf("martian: got error while flushing response back to client: %v", err)
		}

		cbw := bufio.NewWriter(cconn)
		cbr := bufio.NewReader(cconn)
		defer cbw.Flush()

		// copySync copies r to w until the copy ends, then signals donec.
		copySync := func(w io.Writer, r io.Reader, donec chan<- bool) {
			if _, err := io.Copy(w, r); err != nil && err != io.EOF {
				log.Errorf("martian: failed to copy CONNECT tunnel: %v", err)
			}

			log.Debugf("martian: CONNECT tunnel finished copying")
			donec <- true
		}

		// Buffered for both senders so neither goroutine blocks on its send.
		donec := make(chan bool, 2)
		go copySync(cbw, brw, donec)
		go copySync(brw, cbr, donec)

		log.Debugf("martian: established CONNECT tunnel, proxying traffic")
		// Wait for both directions to finish.
		<-donec
		<-donec
		log.Debugf("martian: closed CONNECT tunnel")

		return errClose
	}

	if err := p.reqmod.ModifyRequest(req); err != nil {
		log.Errorf("martian: error modifying request: %v", err)
		proxyutil.Warning(req.Header, err)
	}

	res, err := p.roundTrip(ctx, req)
	if err != nil {
		// Answer with a synthesized 502 instead of failing the connection.
		log.Errorf("martian: failed to round trip: %v", err)
		res = proxyutil.NewResponse(502, nil, req)
		proxyutil.Warning(res.Header, err)
	}
	defer res.Body.Close()

	if err := p.resmod.ModifyResponse(res); err != nil {
		log.Errorf("martian: error modifying response: %v", err)
		proxyutil.Warning(res.Header, err)
	}

	var closing error
	if req.Close || res.Close || p.Closing() {
		log.Debugf("martian: received close request: %v", req.RemoteAddr)
		res.Close = true
		closing = errClose
	}

	log.Debugf("martian: sent response: %v", req.URL)
	if err := res.Write(brw); err != nil {
		log.Errorf("martian: got error while writing response back to client: %v", err)
	}
	if err := brw.Flush(); err != nil {
		log.Errorf("martian: got error while flushing response back to client: %v", err)
	}

	return closing
}
Exemple #7
0
// handleSingle reads one newline-terminated command from rw, executes it
// against the server's in-memory stacks and writes the reply to rw. The
// reply is not flushed here.
//
// Commands (fields separated by whitespace):
//
//	READ topic offset limit  zero or more "MSG <index> <size>\n<payload>\n"
//	                         records followed by "END\n"
//	PUSH topic msg-size      must be followed by msg-size payload bytes and
//	                         a "\n"; replies "OK\n"
//	LEN topic                replies "<count>\n"
//	DUMP                     one "<name> <stack>\n" line per stack, then
//	                         "END\n"
//
// Malformed or unknown commands are answered with an "ERR ..." line and nil
// is returned. An empty line is ignored. The only error returned is a
// failure to read the command line itself.
func (s *Server) handleSingle(rw *bufio.ReadWriter) error {
	line, err := rw.ReadBytes('\n')
	if err != nil {
		return err
	}
	args := bytes.Fields(line)
	if len(args) == 0 {
		return nil
	}

	switch string(args[0]) {
	case "READ":
		// READ topic offset limit\n
		if len(args) != 4 {
			fmt.Fprintf(rw, "ERR READ requires 3 arg, got %d\n", len(args)-1)
			return nil
		}
		s.mu.Lock()
		defer s.mu.Unlock()
		stack, ok := s.stacks[string(args[1])]
		if !ok {
			fmt.Fprintf(rw, "ERR stack %s does not exist\n", args[1])
			return nil
		}
		offset, err := strconv.ParseInt(string(args[2]), 10, 64)
		if err != nil || offset < 0 {
			fmt.Fprint(rw, "ERR invalid offset\n")
			return nil
		}
		limit, err := strconv.ParseInt(string(args[3]), 10, 64)
		if err != nil || limit <= 0 {
			fmt.Fprint(rw, "ERR invalid limit\n")
			return nil
		}

		// Clamp the requested window to the stack. An offset past the end
		// yields an empty result (just "END") so the client always gets a
		// reply. The limit is compared against the remaining length rather
		// than computing offset+limit, which could overflow int64 and
		// produce a negative slice bound.
		total := int64(len(stack))
		if offset > total {
			offset = total
		}
		if limit > total-offset {
			limit = total - offset
		}

		for i, msg := range stack[offset : offset+limit] {
			fmt.Fprintf(rw, "MSG %d %d\n", offset+int64(i), len(msg))
			rw.Write(msg)
			rw.WriteByte('\n')
		}
		rw.WriteString("END\n")

	case "PUSH":
		// PUSH topic msg-size\n
		// msg\n
		if len(args) != 3 {
			fmt.Fprintf(rw, "ERR PUSH requires 2 args, got %d\n", len(args)-1)
			return nil
		}

		msgsize, err := strconv.ParseInt(string(args[2]), 10, 64)
		if err != nil || msgsize <= 0 {
			fmt.Fprint(rw, "ERR invalid message size\n")
			return nil
		}

		// NOTE(review): msgsize comes straight from the client and is not
		// bounded; a very large value makes this allocation panic or
		// exhaust memory. Consider enforcing a maximum message size.
		b := make([]byte, msgsize+1)

		// A single Read may return fewer bytes than requested even though
		// the rest of the message is still on its way, so keep reading
		// until the payload and its terminator have fully arrived.
		for filled := 0; filled < len(b); {
			n, err := rw.Read(b[filled:])
			filled += n
			if filled == len(b) {
				break
			}
			if err != nil {
				fmt.Fprintf(rw, "ERR cannot read: %s\n", err)
				return nil
			}
			if n == 0 {
				// No progress and no error: give up instead of spinning.
				fmt.Fprint(rw, "ERR incompete message\n")
				return nil
			}
		}
		if b[msgsize] != '\n' {
			fmt.Fprint(rw, "ERR invalid message termination\n")
			return nil
		}

		s.mu.Lock()
		defer s.mu.Unlock()
		s.stacks[string(args[1])] = append(s.stacks[string(args[1])], b[:msgsize])

		fmt.Fprint(rw, "OK\n")

	case "LEN":
		// LEN topic \n
		if len(args) != 2 {
			fmt.Fprintf(rw, "ERR LEN requires 1 arg, got %d\n", len(args)-1)
			return nil
		}
		s.mu.Lock()
		defer s.mu.Unlock()
		fmt.Fprintf(rw, "%d\n", len(s.stacks[string(args[1])]))

	case "DUMP":
		s.mu.Lock()
		defer s.mu.Unlock()
		for name, stack := range s.stacks {
			fmt.Fprintf(rw, "%s %s\n", name, stack)
		}
		rw.WriteString("END\n")

	default:
		fmt.Fprintf(rw, "ERR unknown command: %q\n", args)
	}
	return nil
}