This repository has been archived by the owner on Feb 24, 2024. It is now read-only.
-
-
Notifications
You must be signed in to change notification settings - Fork 582
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
moved the CSRF middleware into it's own package to cut down on namespace
polution
- Loading branch information
Showing
3 changed files
with
269 additions
and
218 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,233 +1,23 @@ | ||
package middleware | ||
|
||
import ( | ||
"crypto/rand" | ||
"crypto/subtle" | ||
"encoding/base64" | ||
"errors" | ||
"net/http" | ||
"net/url" | ||
|
||
"github.com/gobuffalo/buffalo" | ||
) | ||
|
||
const ( | ||
// CSRF token length in bytes. | ||
csrfTokenLength int = 32 | ||
csrfTokenKey string = "authenticity_token" | ||
) | ||
|
||
var ( | ||
// The name value used in form fields. | ||
fieldName = csrfTokenKey | ||
|
||
// The HTTP request header to inspect | ||
headerName = "X-CSRF-Token" | ||
|
||
// Idempotent (safe) methods as defined by RFC7231 section 4.2.2. | ||
safeMethods = []string{"GET", "HEAD", "OPTIONS", "TRACE"} | ||
"github.com/gobuffalo/buffalo/middleware/csrf" | ||
) | ||
|
||
var ( | ||
// ErrNoReferer is returned when a HTTPS request provides an empty Referer | ||
// header. | ||
ErrNoReferer = errors.New("referer not supplied") | ||
ErrNoReferer = csrf.ErrNoReferer | ||
// ErrBadReferer is returned when the scheme & host in the URL do not match | ||
// the supplied Referer header. | ||
ErrBadReferer = errors.New("referer invalid") | ||
ErrBadReferer = csrf.ErrBadReferer | ||
// ErrNoCSRFToken is returned if no CSRF token is supplied in the request. | ||
ErrNoCSRFToken = errors.New("CSRF token not found in request") | ||
ErrNoCSRFToken = csrf.ErrNoToken | ||
// ErrBadCSRFToken is returned if the CSRF token in the request does not match | ||
// the token in the session, or is otherwise malformed. | ||
ErrBadCSRFToken = errors.New("CSRF token invalid") | ||
ErrBadCSRFToken = csrf.ErrBadToken | ||
) | ||
|
||
// CSRF enable CSRF protection on routes using this middleware. | ||
// This middleware is adapted from gorilla/csrf | ||
var CSRF = func(next buffalo.Handler) buffalo.Handler { | ||
return func(c buffalo.Context) error { | ||
var realToken []byte | ||
rawRealToken := c.Session().Get(csrfTokenKey) | ||
|
||
if rawRealToken == nil || len(rawRealToken.([]byte)) != csrfTokenLength { | ||
// If the token is missing, or the length if the token is wrong, | ||
// generate a new token. | ||
realToken, err := generateRandomBytes(csrfTokenLength) | ||
if err != nil { | ||
return err | ||
} | ||
// Save the new real token in session | ||
c.Session().Set(csrfTokenKey, realToken) | ||
} else { | ||
realToken = rawRealToken.([]byte) | ||
} | ||
|
||
// Set masked token in context data, to be available in template | ||
c.Set(fieldName, mask(realToken, c.Request())) | ||
|
||
// HTTP methods not defined as idempotent ("safe") under RFC7231 require | ||
// inspection. | ||
if !contains(safeMethods, c.Request().Method) { | ||
// Enforce an origin check for HTTPS connections. As per the Django CSRF | ||
// implementation (https://goo.gl/vKA7GE) the Referer header is almost | ||
// always present for same-domain HTTP requests. | ||
if c.Request().URL.Scheme == "https" { | ||
// Fetch the Referer value. Call the error handler if it's empty or | ||
// otherwise fails to parse. | ||
referer, err := url.Parse(c.Request().Referer()) | ||
if err != nil || referer.String() == "" { | ||
return ErrNoReferer | ||
} | ||
|
||
if sameOrigin(c.Request().URL, referer) == false { | ||
return ErrBadReferer | ||
} | ||
} | ||
|
||
// Retrieve the combined token (pad + masked) token and unmask it. | ||
requestToken := unmask(requestCSRFToken(c.Request())) | ||
|
||
// Missing token | ||
if requestToken == nil { | ||
return ErrNoCSRFToken | ||
} | ||
|
||
// Compare tokens | ||
if !compareTokens(requestToken, realToken) { | ||
return ErrBadCSRFToken | ||
} | ||
} | ||
|
||
return next(c) | ||
} | ||
} | ||
|
||
// generateRandomBytes returns securely generated random bytes. | ||
// It will return an error if the system's secure random number generator | ||
// fails to function correctly. | ||
func generateRandomBytes(n int) ([]byte, error) { | ||
b := make([]byte, n) | ||
_, err := rand.Read(b) | ||
// err == nil only if len(b) == n | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return b, nil | ||
} | ||
|
||
// sameOrigin returns true if URLs a and b share the same origin. The same | ||
// origin is defined as host (which includes the port) and scheme. | ||
func sameOrigin(a, b *url.URL) bool { | ||
return (a.Scheme == b.Scheme && a.Host == b.Host) | ||
} | ||
|
||
// contains is a helper function to check if a string exists in a slice - e.g. | ||
// whether a HTTP method exists in a list of safe methods. | ||
func contains(vals []string, s string) bool { | ||
for _, v := range vals { | ||
if v == s { | ||
return true | ||
} | ||
} | ||
|
||
return false | ||
} | ||
|
||
// compare securely (constant-time) compares the unmasked token from the request | ||
// against the real token from the session. | ||
func compareTokens(a, b []byte) bool { | ||
// This is required as subtle.ConstantTimeCompare does not check for equal | ||
// lengths in Go versions prior to 1.3. | ||
if len(a) != len(b) { | ||
return false | ||
} | ||
|
||
return subtle.ConstantTimeCompare(a, b) == 1 | ||
} | ||
|
||
// xorToken XORs tokens ([]byte) to provide unique-per-request CSRF tokens. It | ||
// will return a masked token if the base token is XOR'ed with a one-time-pad. | ||
// An unmasked token will be returned if a masked token is XOR'ed with the | ||
// one-time-pad used to mask it. | ||
func xorToken(a, b []byte) []byte { | ||
n := len(a) | ||
if len(b) < n { | ||
n = len(b) | ||
} | ||
|
||
res := make([]byte, n) | ||
|
||
for i := 0; i < n; i++ { | ||
res[i] = a[i] ^ b[i] | ||
} | ||
|
||
return res | ||
} | ||
|
||
// mask returns a unique-per-request token to mitigate the BREACH attack | ||
// as per http://breachattack.com/#mitigations | ||
// | ||
// The token is generated by XOR'ing a one-time-pad and the base (session) CSRF | ||
// token and returning them together as a 64-byte slice. This effectively | ||
// randomises the token on a per-request basis without breaking multiple browser | ||
// tabs/windows. | ||
func mask(realToken []byte, r *http.Request) string { | ||
otp, err := generateRandomBytes(csrfTokenLength) | ||
if err != nil { | ||
return "" | ||
} | ||
|
||
// XOR the OTP with the real token to generate a masked token. Append the | ||
// OTP to the front of the masked token to allow unmasking in the subsequent | ||
// request. | ||
return base64.StdEncoding.EncodeToString(append(otp, xorToken(otp, realToken)...)) | ||
} | ||
|
||
// unmask splits the issued token (one-time-pad + masked token) and returns the | ||
// unmasked request token for comparison. | ||
func unmask(issued []byte) []byte { | ||
// Issued tokens are always masked and combined with the pad. | ||
if len(issued) != csrfTokenLength*2 { | ||
return nil | ||
} | ||
|
||
// We now know the length of the byte slice. | ||
otp := issued[csrfTokenLength:] | ||
masked := issued[:csrfTokenLength] | ||
|
||
// Unmask the token by XOR'ing it against the OTP used to mask it. | ||
return xorToken(otp, masked) | ||
} | ||
|
||
// requestCSRFToken gets the CSRF token from either: | ||
// - a HTTP header | ||
// - a form value | ||
// - a multipart form value | ||
func requestCSRFToken(r *http.Request) []byte { | ||
// 1. Check the HTTP header first. | ||
issued := r.Header.Get(headerName) | ||
|
||
// 2. Fall back to the POST (form) value. | ||
if issued == "" { | ||
issued = r.PostFormValue(fieldName) | ||
} | ||
|
||
// 3. Finally, fall back to the multipart form (if set). | ||
if issued == "" && r.MultipartForm != nil { | ||
vals := r.MultipartForm.Value[fieldName] | ||
|
||
if len(vals) > 0 { | ||
issued = vals[0] | ||
} | ||
} | ||
|
||
// Decode the "issued" (pad + masked) token sent in the request. Return a | ||
// nil byte slice on a decoding error (this will fail upstream). | ||
decoded, err := base64.StdEncoding.DecodeString(issued) | ||
if err != nil { | ||
return nil | ||
} | ||
|
||
return decoded | ||
} | ||
var CSRF = csrf.Middleware |
Oops, something went wrong.