diff --git a/realip.go b/realip.go deleted file mode 100644 index 637e88b..0000000 --- a/realip.go +++ /dev/null @@ -1,97 +0,0 @@ -package rest - -import ( - "bytes" - "fmt" - "net" - "net/http" - "strings" -) - -type ipRange struct { - start net.IP - end net.IP -} - -var privateRanges = []ipRange{ - {start: net.ParseIP("10.0.0.0"), end: net.ParseIP("10.255.255.255")}, - {start: net.ParseIP("100.64.0.0"), end: net.ParseIP("100.127.255.255")}, - {start: net.ParseIP("172.16.0.0"), end: net.ParseIP("172.31.255.255")}, - {start: net.ParseIP("192.0.0.0"), end: net.ParseIP("192.0.0.255")}, - {start: net.ParseIP("192.168.0.0"), end: net.ParseIP("192.168.255.255")}, - {start: net.ParseIP("198.18.0.0"), end: net.ParseIP("198.19.255.255")}, -} - -// RealIP is a middleware that sets a http.Request's RemoteAddr to the results -// of parsing either the X-Forwarded-For or X-Real-IP headers. -// -// This middleware should only be used if user can trust the headers sent with request. -// If reverse proxies are configured to pass along arbitrary header values from the client, -// or if this middleware used without a reverse proxy, malicious clients could set anything -// as X-Forwarded-For header and attack the server in various ways. -func RealIP(h http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { - if rip, err := GetIPAddress(r); err == nil { - r.RemoteAddr = rip - } - h.ServeHTTP(w, r) - } - - return http.HandlerFunc(fn) -} - -// GetIPAddress returns real ip from the given request -func GetIPAddress(r *http.Request) (string, error) { - - for _, h := range []string{"X-Forwarded-For", "X-Real-Ip"} { - addresses := strings.Split(r.Header.Get(h), ",") - // march from right to left until we get a public address - // that will be the address right before our proxy. - for i := len(addresses) - 1; i >= 0; i-- { - ip := strings.TrimSpace(addresses[i]) - realIP := net.ParseIP(ip) - if !realIP.IsGlobalUnicast() || isPrivateSubnet(realIP) { - continue - } - return ip, nil - } - } - - // X-Forwarded-For header set but parsing failed above - if r.Header.Get("X-Forwarded-For") != "" { - return "", fmt.Errorf("no valid ip found") - } - - // get IP from RemoteAddr - ip, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - return "", fmt.Errorf("can't parse ip %q: %w", r.RemoteAddr, err) - } - if netIP := net.ParseIP(ip); netIP == nil { - return "", fmt.Errorf("no valid ip found") - } - - return ip, nil -} - -// inRange - check to see if a given ip address is within a range given -func inRange(r ipRange, ipAddress net.IP) bool { - // strcmp type byte comparison - if bytes.Compare(ipAddress, r.start) >= 0 && bytes.Compare(ipAddress, r.end) < 0 { - return true - } - return false -} - -// isPrivateSubnet - check to see if this ip is in a private subnet -func isPrivateSubnet(ipAddress net.IP) bool { - if ipCheck := ipAddress.To4(); ipCheck != nil { - for _, r := range privateRanges { - // check if this ip is in a private range - if inRange(r, ipAddress) { - return true - } - } - } - return false -} diff --git a/realip_test.go b/realip_test.go deleted file mode 100644 index 708a4e3..0000000 --- a/realip_test.go +++ /dev/null @@ -1,113 +0,0 @@ -package rest - -import ( - "log" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestGetIPAddressFromHeaders(t *testing.T) { - { - req, err := http.NewRequest("GET", "/something", http.NoBody) - assert.NoError(t, err) - req.Header.Add("Something", "1234567") - req.Header.Add("X-Real-IP", "8.8.8.8") - adr, err := GetIPAddress(req) - require.NoError(t, err) - assert.Equal(t, "8.8.8.8", adr) - } - { - req, err := http.NewRequest("GET", "/something", http.NoBody) - assert.NoError(t, err) - req.Header.Add("Something", "1234567") - req.Header.Add("X-Forwarded-For", "8.8.8.8,1.1.1.2, 30.30.30.1") - adr, err := GetIPAddress(req) - require.NoError(t, err) - assert.Equal(t, "30.30.30.1", adr) - } - { - req, err := http.NewRequest("GET", "/something", http.NoBody) - assert.NoError(t, err) - req.Header.Add("Something", "1234567") - req.Header.Add("X-Forwarded-For", "8.8.8.8,1.1.1.2,192.168.1.1,10.0.0.65") - adr, err := GetIPAddress(req) - require.NoError(t, err) - assert.Equal(t, "1.1.1.2", adr) - } - { - req, err := http.NewRequest("GET", "/something", http.NoBody) - assert.NoError(t, err) - req.Header.Add("Something", "1234567") - req.Header.Add("X-Forwarded-For", "30.30.30.1") - req.Header.Add("X-Real-Ip", "10.0.0.1") - adr, err := GetIPAddress(req) - require.NoError(t, err) - assert.Equal(t, "30.30.30.1", adr) - } - { - req, err := http.NewRequest("GET", "/something", http.NoBody) - assert.NoError(t, err) - req.Header.Add("Something", "1234567") - req.Header.Add("X-Forwarded-For", "30.30.30.1") - req.Header.Add("X-Real-Ip", "8.8.8.8") - adr, err := GetIPAddress(req) - require.NoError(t, err) - assert.Equal(t, "30.30.30.1", adr) - } - { - req, err := http.NewRequest("GET", "/something", http.NoBody) - assert.NoError(t, err) - req.Header.Add("Something", "1234567") - req.Header.Add("X-Forwarded-For", "10.0.0.2,192.168.1.1") - req.Header.Add("X-Real-Ip", "8.8.8.8") - adr, err := GetIPAddress(req) - require.NoError(t, err) - assert.Equal(t, "8.8.8.8", adr) - } - { - req, err := http.NewRequest("GET", "/something", http.NoBody) - assert.NoError(t, err) - _, err = GetIPAddress(req) - require.Error(t, err) - } -} - -func TestGetIPAddressFromRemoteAddr(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Printf("%v", r) - adr, err := GetIPAddress(r) - require.NoError(t, err) - assert.Equal(t, "127.0.0.1", adr) - })) - - req, err := http.NewRequest("GET", ts.URL+"/something", http.NoBody) - require.NoError(t, err) - client := http.Client{Timeout: time.Second} - _, err = client.Do(req) - require.NoError(t, err) -} - -func TestRealIP(t *testing.T) { - - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Printf("%v", r) - require.Equal(t, "1.2.3.4", r.RemoteAddr) - adr, err := GetIPAddress(r) - require.NoError(t, err) - assert.Equal(t, "1.2.3.4", adr) - }) - - ts := httptest.NewServer(RealIP(handler)) - - req, err := http.NewRequest("GET", ts.URL+"/something", http.NoBody) - require.NoError(t, err) - client := http.Client{Timeout: time.Second} - req.Header.Add("X-Real-IP", "1.2.3.4") - _, err = client.Do(req) - require.NoError(t, err) -}