/
gofriday.go
132 lines (108 loc) · 2.55 KB
/
gofriday.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package main
import (
"crypto/tls"
"crypto/x509"
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"net/url"
"os"
"strconv"
)
// Command-line flags.
var (
	// port is the TCP port the proxy listens on; main falls back to the
	// PORT environment variable when it is left at zero.
	port = flag.Int("port", 0, "HTTP port to listen to (default to $PORT)")

	// ca is the path to a PEM file with the CA certificates used to verify
	// an https remote; empty means no custom CA bundle is loaded.
	ca = flag.String("ca", "", "CA bundle to validate remote server (PEM file)")
)
// ReverseProxy is an http.Handler that forwards every incoming request to a
// single remote server and relays the remote's response back to the caller.
type ReverseProxy struct {
	// Remote is the base URL of the server that requests are forwarded to.
	// Its path is replaced by the path of each incoming request.
	Remote *url.URL
	// Transport is the HTTP transport used to reach the remote server.
	Transport *http.Transport
}

// ServeHTTP forwards req to the remote server and streams the response
// (headers, status code and body) back to out.
//
// The upstream URL is Remote with its path replaced by the request path. The
// query string of the incoming request is forwarded as well; when Remote
// already carries a query, the two are joined with "&". If the upstream
// request fails, the error is logged and the caller receives 502 Bad Gateway.
func (rp *ReverseProxy) ServeHTTP(out http.ResponseWriter, req *http.Request) {
	c := http.Client{
		Transport: rp.Transport,
	}

	// Work on a copy so the shared Remote URL is never mutated.
	proxyurl := *rp.Remote
	proxyurl.Path = req.URL.Path

	// Forward the caller's query string instead of dropping it, keeping any
	// query that is already part of the remote URL.
	switch {
	case req.URL.RawQuery == "":
		// Nothing to forward; keep the remote's own query untouched.
	case proxyurl.RawQuery == "":
		proxyurl.RawQuery = req.URL.RawQuery
	default:
		proxyurl.RawQuery += "&" + req.URL.RawQuery
	}

	// Prepare a request which mirrors the original one.
	proxyreq := &http.Request{
		Method:           req.Method,
		URL:              &proxyurl,
		Header:           req.Header,
		Body:             req.Body,
		ContentLength:    req.ContentLength,
		TransferEncoding: req.TransferEncoding,
		Close:            false,
	}

	resp, err := c.Do(proxyreq)
	if err != nil {
		log.Println("error proxying request", err)
		log.Println("request", req)
		log.Println("response", resp)
		out.WriteHeader(http.StatusBadGateway)
		return
	}
	// Always release the upstream body, whatever happens while relaying it.
	defer resp.Body.Close()

	// Send response header back to client. Headers must be set before
	// WriteHeader is called.
	for k, v := range resp.Header {
		out.Header()[k] = v
	}
	out.WriteHeader(resp.StatusCode)

	if _, err := io.Copy(out, resp.Body); err != nil {
		// The status line is already sent, so the failure can only be logged.
		log.Println("error sending response body", err)
	}
}
// main parses the command line, builds the transport used to reach the remote
// server and serves the reverse proxy until the HTTP listener fails.
//
// The listening port comes from -port or, when that is zero, from the PORT
// environment variable. The remote URL is the first positional argument and
// must use the http or https scheme. For https remotes, -ca names a PEM file
// whose certificates are installed as the root CAs of the transport.
func main() {
	flag.Parse()

	// Fall back to $PORT when no port flag was given.
	if *port == 0 {
		xport, err := strconv.Atoi(os.Getenv("PORT"))
		if err != nil {
			fmt.Println("Please specify the HTTP port (either flag or environment)")
			os.Exit(1)
		}
		*port = xport
	}

	if flag.NArg() < 1 {
		fmt.Println("Specify remote URL on the command line")
		os.Exit(1)
	}
	remote, err := url.Parse(flag.Arg(0))
	if err != nil {
		fmt.Println("error parsing remote URL", err)
		os.Exit(1)
	}

	transport := new(http.Transport)
	switch remote.Scheme {
	case "http":
		if *ca != "" {
			log.Println("ignoring ca flag for non-https remote")
		}
	case "https":
		if *ca != "" {
			data, err := ioutil.ReadFile(*ca)
			if err != nil {
				// log.Fatal exits the process itself.
				log.Fatal(err)
			}
			pool := x509.NewCertPool()
			// An empty pool would make every TLS handshake fail later on,
			// so reject a bundle that yields no certificate right away.
			if !pool.AppendCertsFromPEM(data) {
				log.Fatalf("no valid PEM certificates found in %s", *ca)
			}
			transport.TLSClientConfig = &tls.Config{RootCAs: pool}
		}
	default:
		fmt.Println("unsupported remote scheme:", remote.Scheme)
		os.Exit(1)
	}

	rp := &ReverseProxy{
		Remote:    remote,
		Transport: transport,
	}
	// ListenAndServe only returns on failure, so its error is always fatal.
	log.Fatal(http.ListenAndServe(
		fmt.Sprintf(":%v", *port), rp))
}