Skip to content

Commit

Permalink
Realip (#21)
Browse files Browse the repository at this point in the history
* add realip middleware

* clarify comments

* add more tests

* fix package name

* move ip get to realip package, update docs

* lint: missing package comment and drop unused func
  • Loading branch information
umputun authored Feb 12, 2022
1 parent fe82f24 commit 5647a92
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 17 deletions.
18 changes: 17 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,21 @@ Sets headers (passed as key:value) to requests. I.e. `rest.Headers("Server:MySer

Compresses response with gzip.

## RealIP middleware

RealIP is a middleware that sets a http.Request's RemoteAddr to the results of parsing either the X-Forwarded-For or X-Real-IP headers.

## Maybe middleware

Maybe middleware will allow you to change the flow of the middleware stack execution depending on return
value of maybeFn(request). This is useful for example if you'd like to skip a middleware handler if
a request does not satisfy the maybeFn logic.

## Headers middleware

Headers middleware adds headers to request


## Helpers

- `rest.Wrap` - converts a list of middlewares to nested handlers calls (in reverse order)
Expand All @@ -133,5 +148,6 @@ Profiler is a convenient subrouter used for mounting net/http/pprof, i.e.
return r
}
```
It exposes a whole bunch of `/pprof/*` endpoints as well as `/vars`. Builtin support for `onlyIps` allows to restrict access, which is important if it runs on a publicly exposed port. However, counting on IP check only is not that reliable way to limit request and for production use it would be better to add some sort of auth (for example provided `BasicAuth` middleware) or run with a separate http server, exposed to internal ip/port only.

It exposes a bunch of `/pprof/*` endpoints as well as `/vars`. Builtin support for `onlyIps` allows restricting access, which is important if it runs on a publicly exposed port. However, counting on IP check only is not that reliable way to limit request and for production use it would be better to add some sort of auth (for example provided `BasicAuth` middleware) or run with a separate http server, exposed to internal ip/port only.

22 changes: 7 additions & 15 deletions logger/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"strconv"
"strings"
"time"

"github.com/go-pkgz/rest/realip"
)

// Middleware is a logger for rest requests.
Expand Down Expand Up @@ -103,7 +105,10 @@ func (l *Middleware) Handler(next http.Handler) http.Handler {
rawurl = unescURL
}

remoteIP := l.remoteIP(r)
remoteIP, err := realip.Get(r)
if err != nil {
remoteIP = "unknown ip"
}
if l.ipFn != nil { // mask ip with ipFn
remoteIP = l.ipFn(remoteIP)
}
Expand Down Expand Up @@ -169,7 +174,7 @@ func (l *Middleware) formatDefault(r *http.Request, p *logParts) string {
}

// 127.0.0.1 - frank [10/Oct/2000:13:55:36 -0700] "GET /apache_pb.gif HTTP/1.0" 200 2326 "http://www.example.com/start.html" "Mozilla/4.08 [en] (Win98; I ;Nav)"
//nolint gosec
// nolint gosec
func (l *Middleware) formatApacheCombined(r *http.Request, p *logParts) string {
username := "-"
if p.user != "" {
Expand Down Expand Up @@ -291,19 +296,6 @@ func (l *Middleware) sanitizeQuery(rawQuery string) string {
return query.Encode()
}

// remoteIP gets address from X-Forwarded-For and than from request's remote address
func (l *Middleware) remoteIP(r *http.Request) (remoteIP string) {

if remoteIP = r.Header.Get("X-Forwarded-For"); remoteIP == "" {
remoteIP = r.RemoteAddr
}
remoteIP = strings.Split(remoteIP, ":")[0]
if strings.HasPrefix(remoteIP, "[") {
remoteIP = strings.Split(remoteIP, "]:")[0] + "]"
}
return remoteIP
}

// customResponseWriter is an HTTP response logger that keeps HTTP status code and
// the number of bytes written.
// It implements http.ResponseWriter, http.Flusher and http.Hijacker.
Expand Down
21 changes: 20 additions & 1 deletion middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"

"github.com/go-pkgz/rest/logger"
"github.com/go-pkgz/rest/realip"
)

// Wrap converts a list of middlewares to nested calls (in reverse order)
Expand Down Expand Up @@ -93,7 +94,7 @@ func Headers(headers ...string) func(http.Handler) http.Handler {

// Maybe middleware will allow you to change the flow of the middleware stack execution depending on return
// value of maybeFn(request). This is useful for example if you'd like to skip a middleware handler if
// a request does not satisfied the maybeFn logic.
// a request does not satisfy the maybeFn logic.
// borrowed from https://github.com/go-chi/chi/blob/master/middleware/maybe.go
func Maybe(mw func(http.Handler) http.Handler, maybeFn func(r *http.Request) bool) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
Expand All @@ -106,3 +107,21 @@ func Maybe(mw func(http.Handler) http.Handler, maybeFn func(r *http.Request) boo
})
}
}

// RealIP is a middleware that sets a http.Request's RemoteAddr to the results
// of parsing either the X-Forwarded-For or X-Real-IP headers.
//
// This middleware should only be used if user can trust the headers sent with request.
// If reverse proxies are configured to pass along arbitrary header values from the client,
// or if this middleware used without a reverse proxy, malicious clients could set anything
// as X-Forwarded-For header and attack the server in various ways.
func RealIP(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
if rip, err := realip.Get(r); err == nil {
r.RemoteAddr = rip
}
h.ServeHTTP(w, r)
}

return http.HandlerFunc(fn)
}
22 changes: 22 additions & 0 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ import (
"os"
"sync/atomic"
"testing"
"time"

"github.com/go-pkgz/rest/realip"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -173,6 +175,26 @@ func TestMaybe(t *testing.T) {
}
}

func TestRealIP(t *testing.T) {

handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("%v", r)
require.Equal(t, "1.2.3.4", r.RemoteAddr)
adr, err := realip.Get(r)
require.NoError(t, err)
assert.Equal(t, "1.2.3.4", adr)
})

ts := httptest.NewServer(RealIP(handler))

req, err := http.NewRequest("GET", ts.URL+"/something", http.NoBody)
require.NoError(t, err)
client := http.Client{Timeout: time.Second}
req.Header.Add("X-Real-IP", "1.2.3.4")
_, err = client.Do(req)
require.NoError(t, err)
}

type mockLgr struct {
buf bytes.Buffer
}
Expand Down
80 changes: 80 additions & 0 deletions realip/real.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Package realip extracts a real IP address from the request.
package realip

import (
"bytes"
"fmt"
"net"
"net/http"
"strings"
)

type ipRange struct {
start net.IP
end net.IP
}

var privateRanges = []ipRange{
{start: net.ParseIP("10.0.0.0"), end: net.ParseIP("10.255.255.255")},
{start: net.ParseIP("100.64.0.0"), end: net.ParseIP("100.127.255.255")},
{start: net.ParseIP("172.16.0.0"), end: net.ParseIP("172.31.255.255")},
{start: net.ParseIP("192.0.0.0"), end: net.ParseIP("192.0.0.255")},
{start: net.ParseIP("192.168.0.0"), end: net.ParseIP("192.168.255.255")},
{start: net.ParseIP("198.18.0.0"), end: net.ParseIP("198.19.255.255")},
}

// Get returns real ip from the given request
func Get(r *http.Request) (string, error) {

for _, h := range []string{"X-Forwarded-For", "X-Real-Ip"} {
addresses := strings.Split(r.Header.Get(h), ",")
// march from right to left until we get a public address
// that will be the address right before our proxy.
for i := len(addresses) - 1; i >= 0; i-- {
ip := strings.TrimSpace(addresses[i])
realIP := net.ParseIP(ip)
if !realIP.IsGlobalUnicast() || isPrivateSubnet(realIP) {
continue
}
return ip, nil
}
}

// X-Forwarded-For header set but parsing failed above
if r.Header.Get("X-Forwarded-For") != "" {
return "", fmt.Errorf("no valid ip found")
}

// get IP from RemoteAddr
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return "", fmt.Errorf("can't parse ip %q: %w", r.RemoteAddr, err)
}
if netIP := net.ParseIP(ip); netIP == nil {
return "", fmt.Errorf("no valid ip found")
}

return ip, nil
}

// inRange - check to see if a given ip address is within a range given
func inRange(r ipRange, ipAddress net.IP) bool {
// strcmp type byte comparison
if bytes.Compare(ipAddress, r.start) >= 0 && bytes.Compare(ipAddress, r.end) < 0 {
return true
}
return false
}

// isPrivateSubnet - check to see if this ip is in a private subnet
func isPrivateSubnet(ipAddress net.IP) bool {
if ipCheck := ipAddress.To4(); ipCheck != nil {
for _, r := range privateRanges {
// check if this ip is in a private range
if inRange(r, ipAddress) {
return true
}
}
}
return false
}
94 changes: 94 additions & 0 deletions realip/real_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package realip

import (
"log"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestGetFromHeaders(t *testing.T) {
{
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.Header.Add("Something", "1234567")
req.Header.Add("X-Real-IP", "8.8.8.8")
adr, err := Get(req)
require.NoError(t, err)
assert.Equal(t, "8.8.8.8", adr)
}
{
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.Header.Add("Something", "1234567")
req.Header.Add("X-Forwarded-For", "8.8.8.8,1.1.1.2, 30.30.30.1")
adr, err := Get(req)
require.NoError(t, err)
assert.Equal(t, "30.30.30.1", adr)
}
{
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.Header.Add("Something", "1234567")
req.Header.Add("X-Forwarded-For", "8.8.8.8,1.1.1.2,192.168.1.1,10.0.0.65")
adr, err := Get(req)
require.NoError(t, err)
assert.Equal(t, "1.1.1.2", adr)
}
{
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.Header.Add("Something", "1234567")
req.Header.Add("X-Forwarded-For", "30.30.30.1")
req.Header.Add("X-Real-Ip", "10.0.0.1")
adr, err := Get(req)
require.NoError(t, err)
assert.Equal(t, "30.30.30.1", adr)
}
{
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.Header.Add("Something", "1234567")
req.Header.Add("X-Forwarded-For", "30.30.30.1")
req.Header.Add("X-Real-Ip", "8.8.8.8")
adr, err := Get(req)
require.NoError(t, err)
assert.Equal(t, "30.30.30.1", adr)
}
{
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.Header.Add("Something", "1234567")
req.Header.Add("X-Forwarded-For", "10.0.0.2,192.168.1.1")
req.Header.Add("X-Real-Ip", "8.8.8.8")
adr, err := Get(req)
require.NoError(t, err)
assert.Equal(t, "8.8.8.8", adr)
}
{
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
ip, err := Get(req)
assert.Error(t, err)
assert.Equal(t, "", ip)
}
}

func TestGetFromRemoteAddr(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("%v", r)
adr, err := Get(r)
require.NoError(t, err)
assert.Equal(t, "127.0.0.1", adr)
}))

req, err := http.NewRequest("GET", ts.URL+"/something", http.NoBody)
require.NoError(t, err)
client := http.Client{Timeout: time.Second}
_, err = client.Do(req)
require.NoError(t, err)
}

0 comments on commit 5647a92

Please sign in to comment.