/
tcp-proxy.go
113 lines (100 loc) · 2.06 KB
/
tcp-proxy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
package proxy
import (
	"fmt"
	"io"
	"net"
	"os"
	"sync"
)
// connid is a uint64 starting at zero. Nothing in this file reads or writes
// it; NOTE(review): presumably meant as a connection-ID counter — confirm
// before relying on or removing it.
var connid uint64
// A Proxy represents a pair of connections and their state.
//
// A Proxy is built by NewProxy, which has already accepted the local
// connection; Start then dials the remote side and copies data in both
// directions until either direction fails or reaches EOF.
type Proxy struct {
	sentBytes     uint64 // bytes copied from the local conn to the remote conn
	receivedBytes uint64 // bytes copied from the remote conn to the local conn
	laddr, raddr  *net.TCPAddr
	lconn, rconn  *net.TCPConn
	erred         bool      // set once the first error has been handled
	errsig        chan bool // closed (never sent on) when the first error is handled
	errOnce       sync.Once // ensures only the first error is logged and signalled
}

// NewProxy resolves localAddr and remoteAddr, listens on localAddr and blocks
// until a single client connects. The listener is closed after that one
// accept, so the local port is released rather than leaked.
//
// Any resolve, listen or accept failure is printed to stdout and terminates
// the process with exit status 1.
func NewProxy(localAddr, remoteAddr string) *Proxy {
	laddr, err := net.ResolveTCPAddr("tcp", localAddr)
	check(err)
	raddr, err := net.ResolveTCPAddr("tcp", remoteAddr)
	check(err)
	listener, err := net.ListenTCP("tcp", laddr)
	check(err)
	// Only one connection is ever accepted from this listener; close it
	// afterwards instead of leaving the port bound for the process lifetime.
	// The Close error is irrelevant once we have (or failed to get) the conn.
	defer listener.Close()
	conn, err := listener.AcceptTCP()
	if err != nil {
		// Continuing with a nil lconn would only make Start panic later;
		// treat this like the other setup failures above.
		fmt.Printf("Failed to accept connection '%s'\n", err)
		os.Exit(1)
	}
	return &Proxy{
		lconn:  conn,
		laddr:  laddr,
		raddr:  raddr,
		erred:  false,
		errsig: make(chan bool),
	}
}

// err handles the first error reported for this Proxy and wakes Start.
//
// Only the first call has any effect: it logs s (formatted with err) unless
// err is io.EOF, marks the proxy as erred and closes errsig. Closing instead
// of sending means the call never blocks, even when nobody is receiving on
// errsig, and later calls (e.g. the second pipe goroutine failing after
// Start has closed the connections) return immediately instead of leaking a
// goroutine blocked on a send.
func (p *Proxy) err(s string, err error) {
	p.errOnce.Do(func() {
		if err != io.EOF {
			log(s, err)
		}
		p.erred = true
		close(p.errsig)
	})
}
// Start dials the remote address and then copies data between the local and
// remote connections in both directions until either direction reports an
// error (a clean EOF counts). The local connection is always closed when
// Start returns, and the remote one too if it was dialed successfully.
func (p *Proxy) Start() {
	defer p.lconn.Close()
	// connect to remote
	rconn, err := net.DialTCP("tcp", nil, p.raddr)
	if err != nil {
		// Log directly instead of going through p.err: no pipe goroutine is
		// running and Start is not waiting on p.errsig, so signalling on that
		// unbuffered channel could block forever.
		log("Remote connection failed: %s", err)
		return
	}
	p.rconn = rconn
	defer p.rconn.Close()
	// display both ends
	// p.log("Opened %s >>> %s", p.lconn.RemoteAddr().String(), p.rconn.RemoteAddr().String())
	// bidirectional copy
	go p.pipe(p.lconn, p.rconn)
	go p.pipe(p.rconn, p.lconn)
	// Wait for the first error from either direction. The deferred Close
	// calls then unblock whichever pipe goroutine is still reading.
	<-p.errsig
	// p.log("Closed (%d bytes sent, %d bytes recieved)", p.sentBytes, p.receivedBytes)
}
// pipe copies data from src to dst until a read or write fails, then reports
// the failure through p.err and returns. Successfully written bytes are added
// to sentBytes when src is the local connection and to receivedBytes
// otherwise.
func (p *Proxy) pipe(src, dst *net.TCPConn) {
	// data direction
	islocal := src == p.lconn
	// directional copy (0xffff-byte buffer, just under 64 KiB)
	buff := make([]byte, 0xffff)
	for {
		n, readErr := src.Read(buff)
		// Per the io.Reader contract, forward any bytes returned together with
		// an error before acting on that error.
		if n > 0 {
			written, writeErr := dst.Write(buff[:n])
			if writeErr != nil {
				// log() already appends a newline, so the format has none.
				p.err("Write failed '%s'", writeErr)
				return
			}
			if islocal {
				p.sentBytes += uint64(written)
			} else {
				p.receivedBytes += uint64(written)
			}
		}
		if readErr != nil {
			p.err("Read failed '%s'", readErr)
			return
		}
	}
}
//helper functions
// check is used for unrecoverable setup failures: when err is non-nil it
// prints the error message via log and terminates the process with exit
// status 1. A nil err is a no-op.
func check(err error) {
	if err != nil {
		// Pass the message as an argument rather than as the format string:
		// error text containing '%' would otherwise be mangled by Printf.
		log("%s", err.Error())
		os.Exit(1)
	}
}
// log writes f, formatted with args exactly as fmt.Printf would, to standard
// output and terminates the line with a newline.
func log(f string, args ...interface{}) {
	line := f + "\n"
	fmt.Printf(line, args...)
}