/
proxy.go
133 lines (118 loc) · 4.13 KB
/
proxy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
package main
import (
"flag"
"fmt"
"log"
"net"
"time"
)
// Command-line configuration for the proxy.
var (
	// fromHost is the address the proxy listens on.
	fromHost = flag.String("from", "localhost:80", "The proxy server's host.")
	// toHost is the address every accepted connection is forwarded to.
	toHost = flag.String("to", "localhost:8000", "The host that the proxy " +
		"server should forward requests to.")
	// maxConnections is the number of connections proxied concurrently; it
	// sizes the pool of free-slot tokens handed out by matchConnections.
	maxConnections = flag.Int("c", 25, "The maximum number of active " +
		"connections at any given time.")
	// maxWaitingConnections is the capacity of the queue of accepted
	// connections that are waiting for a free slot.
	maxWaitingConnections = flag.Int("cw", 10000, "The maximum number of " +
		"connections that can be waiting to be served.")
)
func main() {
// Parse the command-line arguments.
flag.Parse()
fmt.Printf("Proxying %s->%s.\r\n", *fromHost, *toHost)
// Set up our listening server
server, err := net.Listen("tcp", *fromHost)
// If any error occurs while setting up our listening server, error out.
if err != nil {
log.Fatal(err)
}
// The channel of connections which are waiting to be processed.
waiting := make(chan net.Conn, *maxWaitingConnections)
// The booleans representing the free active connection spaces.
spaces := make(chan bool, *maxConnections)
// Initialize the spaces
for i := 0; i < *maxConnections; i++ {
spaces <- true
}
// Start the connection matcher.
go matchConnections(waiting, spaces)
// Loop indefinitely, accepting connections and handling them.
for {
connection, err := server.Accept()
if err != nil {
// Log the error.
log.Print(err)
} else {
// Create a goroutine to handle the conn
log.Printf("Received connection from %s.\r\n",
connection.RemoteAddr())
waiting <- connection
}
}
}
// matchConnections pairs each queued connection with a free slot token taken
// from spaces. Each connection is served on its own goroutine, and the token
// is put back once handleConnection returns. The function returns when
// waiting is closed.
func matchConnections(waiting chan net.Conn, spaces chan bool) {
	for {
		conn, open := <-waiting
		if !open {
			return
		}

		// Claim a slot; this blocks while every slot is in use.
		<-spaces

		go func(c net.Conn) {
			handleConnection(c)
			// Return the slot so the next queued connection can proceed.
			spaces <- true
			log.Printf("Closed connection from %s.\r\n", c.RemoteAddr())
		}(conn)
	}
}
// handleConnection proxies a single client connection to *toHost.
//
// It dials the upstream and starts one copyContent goroutine per direction.
// It returns once both goroutines have reported on the shared completion
// channel. The client connection is always closed on return. The upstream
// connection is closed too when the dial succeeded; a failed dial is only
// logged.
func handleConnection(connection net.Conn) {
	defer connection.Close()

	upstream, dialErr := net.Dial("tcp", *toHost)
	if dialErr != nil {
		log.Print(dialErr)
		return
	}
	defer upstream.Close()

	// finished receives one signal from each copier. clientDone and
	// upstreamDone let a copier that stops on an error tell its partner to
	// stop as well.
	finished := make(chan bool, 2)
	clientDone, upstreamDone := make(chan bool, 1), make(chan bool, 1)

	go copyContent(connection, upstream, finished, clientDone, upstreamDone)
	go copyContent(upstream, connection, finished, upstreamDone, clientDone)

	for pending := 2; pending > 0; pending-- {
		<-finished
	}
}
// copyContent forwards bytes read from `from` to `to` in one direction. It
// stops when a read or write fails, or when the partner goroutine copying
// the opposite direction signals on otherDone.
//
// Exactly one value is sent on complete before returning. When this
// goroutine stops because of an I/O error, it also sends on done, so the
// partner stops at its next poll.
//
// The read deadline is only a polling interval for otherDone. An expired
// deadline means "no data yet" and does not end the connection, so idle or
// one-way streams (e.g. a long download with a silent client) stay open.
// A write that cannot complete within the deadline is still treated as
// fatal, so a stalled peer cannot pin this goroutine forever.
func copyContent(from net.Conn, to net.Conn, complete chan bool, done chan bool, otherDone chan bool) {
	const ioTimeout = 5 * time.Second
	buf := make([]byte, 32*1024)

	// finish reports that this direction stopped on its own.
	finish := func() {
		complete <- true
		done <- true
	}

	for {
		// Non-blocking check: stop if the other direction has already ended.
		select {
		case <-otherDone:
			complete <- true
			return
		default:
		}

		from.SetReadDeadline(time.Now().Add(ioTimeout))
		n, err := from.Read(buf)

		// Per the io.Reader contract, forward any bytes read before looking
		// at the error that may accompany them.
		if n > 0 {
			to.SetWriteDeadline(time.Now().Add(ioTimeout))
			if _, werr := to.Write(buf[:n]); werr != nil {
				finish()
				return
			}
		}

		if err != nil {
			// Deadline expiry reports Timeout() == true; just loop to re-check
			// otherDone instead of tearing down an idle connection.
			if ne, ok := err.(net.Error); ok && ne.Timeout() {
				continue
			}
			finish()
			return
		}
	}
}