Skip to content
This repository has been archived by the owner on Mar 5, 2024. It is now read-only.

Commit

Permalink
Always proxy through requests to get a session token
Browse files Browse the repository at this point in the history
(Regardless of whether the path is in the whitelist or not)
  • Loading branch information
rbvigilante committed Mar 3, 2020
1 parent 07d49e6 commit bb113f1
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
6 changes: 5 additions & 1 deletion pkg/aws/metadata/handler_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ type proxyHandler struct {
whitelistRouteRegexp *regexp.Regexp
}

var tokenRouteRegexp = regexp.MustCompile("^/?[^/]+/api/token$")

func (p *proxyHandler) Install(router *mux.Router) {
router.PathPrefix("/").Handler(adapt(withMeter("proxy", p)))
}
Expand All @@ -42,7 +44,9 @@ func (w *teeWriter) WriteHeader(statusCode int) {
}

func (p *proxyHandler) Handle(ctx context.Context, w http.ResponseWriter, r *http.Request) (int, error) {
if p.whitelistRouteRegexp.MatchString(r.URL.Path) {
if p.whitelistRouteRegexp.MatchString(r.URL.Path) ||
// Always proxy through requests to pick up a session token
(r.Method == http.MethodPut && tokenRouteRegexp.MatchString(r.URL.Path)) {
writer := &teeWriter{w, http.StatusOK}
// Passing the request through with no RemoteAddr prevents the backing service adding an X-Forwarded-For header.
// This is important, because v2 of the EC2 Instance Metadata API blocks all requests containing such a header
Expand Down
27 changes: 20 additions & 7 deletions pkg/aws/metadata/handler_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
)

func performRequest(allowed, path string, returnCode int) (int, *httptest.ResponseRecorder) {
func performRequest(allowed, path string, method string, returnCode int) (int, *httptest.ResponseRecorder) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()

Expand All @@ -40,7 +40,7 @@ func performRequest(allowed, path string, returnCode int) (int, *httptest.Respon
router := mux.NewRouter()
handler.Install(router)

r, _ := http.NewRequest("GET", path, nil)
r, _ := http.NewRequest(method, path, nil)
rr := httptest.NewRecorder()

router.ServeHTTP(rr, r.WithContext(ctx))
Expand All @@ -51,7 +51,7 @@ func performRequest(allowed, path string, returnCode int) (int, *httptest.Respon
func TestProxyDefaultBlacklistingRoot(t *testing.T) {
defer leaktest.Check(t)()

hits, rr := performRequest("", "/", http.StatusOK)
hits, rr := performRequest("", "/", "GET", http.StatusOK)

if hits != 0 {
t.Error("unexpected reverse proxy hit")
Expand Down Expand Up @@ -84,7 +84,7 @@ func TestProxyFiltering(t *testing.T) {

requestsInitial := readPrometheusCounterValue("kiam_metadata_responses_total", "handler", "proxy")
blockedInitial := readPrometheusSimpleCounterValue("kiam_metadata_proxy_requests_blocked_total")
hits, rr := performRequest("foo.*", "/bar", http.StatusOK)
hits, rr := performRequest("foo.*", "/bar", "GET", http.StatusOK)

if hits != 0 {
t.Error("unexpected reverse proxy hit")
Expand All @@ -106,10 +106,23 @@ func TestProxyFiltering(t *testing.T) {
}
}

func TestTokenRoute(t *testing.T) {
defer leaktest.Check(t)()

hits, rr := performRequest("foo.*", "/latest/api/token", "PUT", http.StatusOK)

if hits != 1 {
t.Error("expected reverse proxy hit")
}
if rr.Code != http.StatusOK {
t.Error("unexpected status", rr.Code)
}
}

func TestProxyFilteringSubpath(t *testing.T) {
defer leaktest.Check(t)()

hits, rr := performRequest("foo.*", "/bar/baz", http.StatusOK)
hits, rr := performRequest("foo.*", "/bar/baz", "GET", http.StatusOK)

if hits != 0 {
t.Error("unexpected reverse proxy hit")
Expand All @@ -125,7 +138,7 @@ func TestProxyFilteringSubpath(t *testing.T) {
func TestProxyWhitelisting(t *testing.T) {
defer leaktest.Check(t)()

hits, rr := performRequest("foo.*", "/foo", http.StatusOK)
hits, rr := performRequest("foo.*", "/foo", "GET", http.StatusOK)

if hits != 1 {
t.Error("expected reverse proxy hit")
Expand All @@ -138,7 +151,7 @@ func TestProxyWhitelisting(t *testing.T) {
func TestErrorReturned(t *testing.T) {
defer leaktest.Check(t)()

hits, rr := performRequest("foo.*", "/foo", http.StatusForbidden)
hits, rr := performRequest("foo.*", "/foo", "GET", http.StatusForbidden)

if hits != 1 {
t.Error("expected reverse proxy hit")
Expand Down

0 comments on commit bb113f1

Please sign in to comment.