forked from samuong/alpaca
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathproxy_test.go
235 lines (211 loc) · 7.6 KB
/
proxy_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
package main
import (
"bufio"
"crypto/tls"
"crypto/x509"
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
// testServer is the origin server used by the proxy tests. Each request it
// handles is reported on requests as "<METHOD> to server", and every response
// is 200 OK with the body "Hello, client\n".
type testServer struct {
	requests chan<- string
}

// ServeHTTP records the request method and answers with a fixed greeting.
func (ts testServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	note := req.Method + " to server"
	ts.requests <- note
	w.WriteHeader(http.StatusOK)
	fmt.Fprintln(w, "Hello, client")
}
// testProxy wraps a proxy handler so the tests can observe the order in which
// requests pass through it. Each request is reported on requests as
// "<METHOD> to <name>" before being handed to delegate unchanged.
type testProxy struct {
	requests chan<- string
	name     string
	delegate http.Handler
}

// ServeHTTP records the request, then forwards it to the wrapped handler.
func (tp testProxy) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	note := fmt.Sprintf("%s to %s", req.Method, tp.name)
	tp.requests <- note
	next := tp.delegate
	next.ServeHTTP(w, req)
}
// newDirectProxy builds a ProxyHandler whose upstream lookup always returns a
// nil URL and nil error. NOTE(review): by http.Transport convention a nil proxy
// URL means "connect directly"; confirm NewProxyHandler treats it that way.
func newDirectProxy() ProxyHandler {
	noUpstream := func(*http.Request) (*url.URL, error) {
		return nil, nil
	}
	return NewProxyHandler(noUpstream, nil)
}
// newChildProxy builds a ProxyHandler whose upstream lookup always names the
// listening address of parent. The lookup produces a proxy chain:
// client -> child -> parent.
func newChildProxy(parent *httptest.Server) ProxyHandler {
	viaParent := func(*http.Request) (*url.URL, error) {
		upstream := &url.URL{Host: parent.Listener.Addr().String()}
		return upstream, nil
	}
	return NewProxyHandler(viaParent, nil)
}
// proxyServer returns a proxy-selection function (as used by
// http.Transport.Proxy) that sends every request through the given test proxy
// server. The test stops immediately if the server URL fails to parse.
func proxyServer(t *testing.T, proxy *httptest.Server) proxyFunc {
	proxyURL, err := url.Parse(proxy.URL)
	require.Nil(t, err)
	selector := http.ProxyURL(proxyURL)
	return selector
}
// tlsConfig returns a client TLS configuration whose only trusted root is the
// certificate of the given httptest TLS server. Clients using it can verify
// that server without skipping certificate checks.
func tlsConfig(server *httptest.Server) *tls.Config {
	roots := x509.NewCertPool()
	roots.AddCert(server.Certificate())
	cfg := &tls.Config{}
	cfg.RootCAs = roots
	return cfg
}
// testGetRequest issues a GET to serverURL using tr. It asserts a 200 response
// whose body is exactly the greeting written by testServer.
func testGetRequest(t *testing.T, tr *http.Transport, serverURL string) {
	c := &http.Client{Transport: tr}
	resp, err := c.Get(serverURL)
	require.Nil(t, err)
	defer resp.Body.Close()
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	body, err := ioutil.ReadAll(resp.Body)
	require.Nil(t, err)
	assert.Equal(t, "Hello, client\n", string(body))
}
// TestGetViaProxy sends a plain-HTTP GET through one direct proxy. It checks
// that the request is seen first by the proxy and then by the origin server.
func TestGetViaProxy(t *testing.T) {
	requests := make(chan string, 2)
	server := httptest.NewServer(testServer{requests})
	defer server.Close()
	proxy := httptest.NewServer(testProxy{requests, "proxy", newDirectProxy()})
	defer proxy.Close()
	transport := &http.Transport{Proxy: proxyServer(t, proxy)}
	testGetRequest(t, transport, server.URL)
	want := []string{"GET to proxy", "GET to server"}
	require.Len(t, requests, len(want))
	for _, w := range want {
		assert.Equal(t, w, <-requests)
	}
}
// TestGetOverTlsViaProxy sends an HTTPS GET through one direct proxy. It checks
// that the proxy only sees a CONNECT, while the GET itself reaches the TLS
// origin through the tunnel.
func TestGetOverTlsViaProxy(t *testing.T) {
	requests := make(chan string, 2)
	server := httptest.NewTLSServer(testServer{requests})
	defer server.Close()
	proxy := httptest.NewServer(testProxy{requests, "proxy", newDirectProxy()})
	defer proxy.Close()
	viaProxy := proxyServer(t, proxy)
	trustServer := tlsConfig(server)
	transport := &http.Transport{Proxy: viaProxy, TLSClientConfig: trustServer}
	testGetRequest(t, transport, server.URL)
	want := []string{"CONNECT to proxy", "GET to server"}
	require.Len(t, requests, len(want))
	for _, w := range want {
		assert.Equal(t, w, <-requests)
	}
}
// TestGetViaTwoProxies sends a plain-HTTP GET through a child proxy that chains
// to a parent proxy. It checks that the request passes child, parent, then
// origin, in that order.
func TestGetViaTwoProxies(t *testing.T) {
	requests := make(chan string, 3)
	server := httptest.NewServer(testServer{requests})
	defer server.Close()
	parent := httptest.NewServer(testProxy{requests, "parent proxy", newDirectProxy()})
	defer parent.Close()
	child := httptest.NewServer(testProxy{requests, "child proxy", newChildProxy(parent)})
	defer child.Close()
	transport := &http.Transport{Proxy: proxyServer(t, child)}
	testGetRequest(t, transport, server.URL)
	require.Len(t, requests, 3)
	// Receives in a composite literal are evaluated left to right.
	got := []string{<-requests, <-requests, <-requests}
	assert.Equal(t, "GET to child proxy", got[0])
	assert.Equal(t, "GET to parent proxy", got[1])
	assert.Equal(t, "GET to server", got[2])
}
// TestGetOverTlsViaTwoProxies sends an HTTPS GET through a chain of two
// proxies. It checks that the CONNECT is relayed child -> parent and the GET
// only reaches the TLS origin.
func TestGetOverTlsViaTwoProxies(t *testing.T) {
	requests := make(chan string, 3)
	server := httptest.NewTLSServer(testServer{requests})
	defer server.Close()
	parent := httptest.NewServer(testProxy{requests, "parent proxy", newDirectProxy()})
	defer parent.Close()
	child := httptest.NewServer(testProxy{requests, "child proxy", newChildProxy(parent)})
	defer child.Close()
	viaChild := proxyServer(t, child)
	trustServer := tlsConfig(server)
	transport := &http.Transport{Proxy: viaChild, TLSClientConfig: trustServer}
	testGetRequest(t, transport, server.URL)
	require.Len(t, requests, 3)
	// Receives in a composite literal are evaluated left to right.
	got := []string{<-requests, <-requests, <-requests}
	assert.Equal(t, "CONNECT to child proxy", got[0])
	assert.Equal(t, "CONNECT to parent proxy", got[1])
	assert.Equal(t, "GET to server", got[2])
}
// hopByHopTestServer is a handler that asserts, via t, which request headers
// survived the trip through the proxy:
//   - Connection, Proxy-Authorization and X-Alpaca-Request (named in the
//     client's Connection header) must be absent.
//   - Authorization must still be present.
//
// Its response sets Connection: X-Alpaca-Response, an X-Alpaca-Response header,
// and Cache-Control.
type hopByHopTestServer struct {
	t *testing.T
}

// ServeHTTP checks the incoming headers, then replies 200 OK with a mix of
// connection-scoped and end-to-end response headers.
func (s hopByHopTestServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	in := req.Header
	for _, name := range []string{"Connection", "Proxy-Authorization"} {
		assert.NotContains(s.t, in, name)
	}
	assert.Contains(s.t, in, "Authorization")
	assert.NotContains(s.t, in, "X-Alpaca-Request")

	out := w.Header()
	out.Set("Connection", "X-Alpaca-Response")
	out.Set("Cache-Control", "no-cache")
	out.Set("X-Alpaca-Response", "this should get dropped")
	w.WriteHeader(http.StatusOK)
}
// testHopByHopHeaders sends a request with the given method to target through
// proxy. The request carries:
//   - connection-scoped headers: Connection, Proxy-Authorization, and the
//     X-Alpaca-Request header that Connection names;
//   - an end-to-end Authorization header.
//
// The handler at target performs the request-side checks. Here we require a
// 200 response and assert that:
//   - Connection and X-Alpaca-Response were stripped from the response;
//   - Cache-Control survived.
//
// The parameter is named target rather than url so it does not shadow the
// imported net/url package.
func testHopByHopHeaders(t *testing.T, method, target string, proxy proxyFunc) {
	req, err := http.NewRequest(method, target, nil)
	require.Nil(t, err)
	req.Header.Set("Connection", "X-Alpaca-Request")
	req.Header.Set("Proxy-Authorization", "Basic bWFsb3J5YXJjaGVyOmd1ZXN0")
	req.Header.Set("Authorization", "Basic bmlrb2xhaWpha292Omd1ZXN0")
	req.Header.Set("X-Alpaca-Request", "this should get dropped")
	tr := &http.Transport{Proxy: proxy}
	// Registered before the body-close defer so it runs after it: once the body is
	// closed the connection returns to the pool, and this drops it so it does not
	// outlive the test servers.
	defer tr.CloseIdleConnections()
	resp, err := tr.RoundTrip(req)
	require.Nil(t, err)
	defer resp.Body.Close()
	require.Equal(t, http.StatusOK, resp.StatusCode)
	assert.NotContains(t, resp.Header, "Connection")
	assert.Contains(t, resp.Header, "Cache-Control")
	assert.NotContains(t, resp.Header, "X-Alpaca-Response")
}
// TestHopByHopHeaders runs the hop-by-hop header checks for an ordinary GET
// forwarded by a direct proxy to the checking origin server.
func TestHopByHopHeaders(t *testing.T) {
	origin := httptest.NewServer(hopByHopTestServer{t})
	defer origin.Close()
	direct := httptest.NewServer(newDirectProxy())
	defer direct.Close()
	viaDirect := proxyServer(t, direct)
	testHopByHopHeaders(t, http.MethodGet, origin.URL, viaDirect)
}
// TestHopByHopHeadersForConnectRequest runs the hop-by-hop header checks for a
// CONNECT request. A child proxy forwards it to an upstream that is the
// checking handler.
func TestHopByHopHeadersForConnectRequest(t *testing.T) {
	upstream := httptest.NewServer(hopByHopTestServer{t})
	defer upstream.Close()
	child := httptest.NewServer(newChildProxy(upstream))
	defer child.Close()
	viaChild := proxyServer(t, child)
	testHopByHopHeaders(t, http.MethodConnect, upstream.URL, viaChild)
}
// TestDeleteConnectionTokens checks that deleteConnectionTokens removes the
// headers named by every Connection value (including comma-separated lists
// spread over multiple Connection lines). Headers not named there must be left
// untouched.
func TestDeleteConnectionTokens(t *testing.T) {
	dropped := []string{"X-Alpaca-1", "X-Alpaca-2"}
	h := http.Header{}
	h.Add("Connection", "close")
	h.Add("Connection", "x-alpaca-1, x-alpaca-2")
	for _, name := range dropped {
		h.Set(name, "this should get dropped")
	}
	h.Set("X-Alpaca-3", "this should NOT get dropped")
	deleteConnectionTokens(h)
	for _, name := range dropped {
		assert.NotContains(t, h, name)
	}
	assert.Contains(t, h, "X-Alpaca-3")
}
// TestCloseFromOneSideResultsInEOFOnOtherSide checks that closing either end
// of a CONNECT tunnel is propagated through the proxy as EOF on the other end.
func TestCloseFromOneSideResultsInEOFOnOtherSide(t *testing.T) {
	hangUp := func(c net.Conn) { c.Close() }
	expectEOF := func(c net.Conn) {
		r := bufio.NewReader(c)
		_, err := r.Peek(1)
		assert.Equal(t, io.EOF, err)
	}
	// The server hangs up; the client must observe EOF.
	testProxyTunnel(t, hangUp, expectEOF)
	// The client hangs up; the server must observe EOF.
	testProxyTunnel(t, expectEOF, hangUp)
}
// testProxyTunnel opens a CONNECT tunnel through a direct proxy to a plain TCP
// listener. It runs onServer on the server's end of the tunnel in a separate
// goroutine and onClient on the client's end, then returns only after onServer
// has finished.
//
// Only the goroutine running the test may call t.FailNow (and therefore any
// require.* helper). The server goroutine therefore hands an Accept error back
// to the test goroutine instead of failing the test itself.
func testProxyTunnel(t *testing.T, onServer, onClient func(conn net.Conn)) {
	// Set up a Listener to act as a server, which we'll connect to via the proxy.
	server, err := net.Listen("tcp", "localhost:0")
	require.Nil(t, err)
	defer server.Close()
	proxy := httptest.NewServer(newDirectProxy())
	defer proxy.Close()
	client, err := net.Dial("tcp", proxy.Listener.Addr().String())
	require.Nil(t, err)
	defer client.Close()
	// The server just accepts a connection and calls the callback. acceptErr is
	// written only by the goroutine and read only after done is closed, so the
	// channel close orders the two accesses.
	var acceptErr error
	done := make(chan struct{})
	go func() {
		defer close(done)
		conn, err := server.Accept()
		if err != nil {
			acceptErr = err
			return
		}
		// Release the server end even when onServer does not close it itself;
		// a second Close on an already-closed conn only returns an error.
		defer conn.Close()
		onServer(conn)
	}()
	// Connect to the server via the proxy, using a CONNECT request.
	serverURL := url.URL{Host: server.Addr().String()}
	req, err := http.NewRequest(http.MethodConnect, serverURL.String(), nil)
	require.Nil(t, err)
	require.Nil(t, req.Write(client))
	resp, err := http.ReadResponse(bufio.NewReader(client), req)
	require.Nil(t, err)
	require.Equal(t, http.StatusOK, resp.StatusCode)
	// Call the client callback, and then make sure that the server is done before finishing.
	onClient(client)
	<-done
	require.Nil(t, acceptErr)
}