-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtoken.go
141 lines (116 loc) · 3.81 KB
/
token.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
// Package nosurfctx generates and verifies CSRF tokens, handing the
// masked token to handlers through the request context.
//
// There are two types of tokens.
//
// The unmasked "real" token consists of 32 random bytes.
// It is stored in a cookie (base64 encoded) and is the
// "reference" value that sent tokens are compared to.
//
// The masked "sent" token consists of 64 bytes:
//   - a 32-byte key used for one-time pad masking, and
//   - the 32-byte "real" token masked with that key.
//
// It is used as the CSRF token value (also base64 encoded)
// in forms and/or headers.
package nosurfctx
import (
"context"
"crypto/rand"
"crypto/subtle"
"fmt"
"io"
"net/http"
)
// CSRF token configuration shared by the cookie, form and header handling.
const (
	// tokenLength is the size, in bytes, of the unmasked "real" token.
	tokenLength = 32

	// cookieName names the cookie that carries the real token.
	cookieName = "csrf_token"

	// formFieldName names the form field that carries the sent token.
	formFieldName = "csrf_token"

	// headerName names the request header that carries the sent token.
	headerName = "X-CSRF-Token"

	// maxAge is the CSRF cookie lifetime in seconds (365 days).
	maxAge = 60 * 60 * 24 * 365
)
// key is the unexported context key type used by this package; being
// unexported, it cannot collide with context keys defined elsewhere.
type key int

// csrfKey is the context key under which the masked, base64 encoded
// token is stored.
var csrfKey key = 1

// Token returns the masked, base64 encoded CSRF token stored in the
// given request's context, for use in forms and/or request headers.
//
// If no token has been stored in the context (for example, the request
// never went through setTokenContext), Token returns the empty string
// rather than panicking on a failed type assertion.
func Token(r *http.Request) string {
	// The ok result is deliberately discarded: a missing or non-string
	// value yields the zero value "", which is the documented result.
	token, _ := r.Context().Value(csrfKey).(string)
	return token
}
// generateToken returns a new unmasked "real" token of tokenLength bytes
// read from crypto/rand.
//
// On failure it returns a nil slice and an error that wraps the
// underlying reader error, so callers can inspect it with errors.Is or
// errors.As.
func generateToken() ([]byte, error) {
	token := make([]byte, tokenLength)
	if _, err := io.ReadFull(rand.Reader, token); err != nil {
		return nil, fmt.Errorf("could not generate random bytes for token: %w", err)
	}
	return token, nil
}
// getTokenFromCookie returns the base64 decoded value of the CSRF cookie.
// When the request has no such cookie it returns nil. What it returns for
// a cookie value that is not valid base64 depends on b64decode (not
// visible here).
func getTokenFromCookie(r *http.Request) []byte {
	c, err := r.Cookie(cookieName)
	if err != nil {
		return nil
	}
	return b64decode(c.Value)
}
// getTokenFromRequest returns the base64 decoded token sent by the client.
//
// Sources are consulted in order, and the first non-empty one wins:
//  1. the headerName request header;
//  2. the formFieldName POST form value;
//  3. the first formFieldName value of an already-parsed multipart form.
//
// The body is only parsed (via PostFormValue) when the header is empty.
// If every source is empty, the empty string is decoded.
func getTokenFromRequest(r *http.Request) []byte {
	if h := r.Header.Get(headerName); h != "" {
		return b64decode(h)
	}
	if v := r.PostFormValue(formFieldName); v != "" {
		return b64decode(v)
	}
	// PostFormValue has already invoked ParseMultipartForm when needed,
	// so r.MultipartForm is populated here for multipart bodies.
	var fallback string
	if mf := r.MultipartForm; mf != nil {
		if vals := mf.Value[formFieldName]; len(vals) > 0 {
			fallback = vals[0]
		}
	}
	return b64decode(fallback)
}
// setTokenCookie writes the CSRF cookie to w. The cookie carries the given
// token base64 encoded; per the package documentation, this is the
// unmasked "real" token. The cookie is scoped to the whole site ("/") and
// lives for maxAge seconds.
//
// The cookie is marked HttpOnly. Scripts have no use for the raw cookie
// value: verifyToken only accepts the 64-byte masked "sent" token (exposed
// via Token), never the 32-byte real token, so hiding the cookie from
// scripts only reduces exposure.
func setTokenCookie(w http.ResponseWriter, token []byte) {
	http.SetCookie(w, &http.Cookie{
		Name:     cookieName,
		Value:    b64encode(token),
		Path:     "/",
		MaxAge:   maxAge,
		HttpOnly: true,
	})
}
// setTokenContext returns a shallow copy of r (via r.WithContext) whose
// context holds the masked, base64 encoded form of token, for use in
// forms and ajax requests.
//
// If masking fails, r itself is returned unchanged along with the error.
func setTokenContext(r *http.Request, token []byte) (*http.Request, error) {
	masked, err := maskToken(token)
	if err != nil {
		return r, err
	}
	ctx := context.WithValue(r.Context(), csrfKey, b64encode(masked))
	return r.WithContext(ctx), nil
}
// verifyToken reports whether sentToken, once unmasked, equals realToken.
//
// realToken must be the base64 decoded real token of tokenLength bytes.
// sentToken must be the base64 decoded masked token of 2*tokenLength bytes.
// A wrong length on either argument returns false and an error naming the
// malformed token. An error from unmaskToken is returned unchanged.
func verifyToken(realToken, sentToken []byte) (bool, error) {
	if n := len(realToken); n != tokenLength {
		return false, fmt.Errorf("real token has length %d, want %d", n, tokenLength)
	}
	if n := len(sentToken); n != tokenLength*2 {
		return false, fmt.Errorf("sent token has length %d, want %d", n, tokenLength*2)
	}

	sentPlain, err := unmaskToken(sentToken)
	if err != nil {
		return false, err
	}

	// Constant-time comparison keeps the running time independent of how
	// many leading bytes match, so timing leaks nothing about the token.
	return subtle.ConstantTimeCompare(realToken, sentPlain) == 1, nil
}