-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathcsrf.go
248 lines (210 loc) · 6.79 KB
/
csrf.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
// Copyright 2021 Flamego. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
// Package csrf is a middleware that generates and validates CSRF tokens for Flamego.
package csrf
import (
	crand "crypto/rand"
	"encoding/gob"
	"fmt"
	"math/rand"
	"net/http"
	"reflect"
	"time"

	"github.com/flamego/flamego"
	"github.com/flamego/flamego/inject"
	"github.com/flamego/session"
)
// init registers time.Time with encoding/gob so that time.Time values held in
// interface-typed slots can be gob-encoded. Csrfer stores the token expiry
// (a time.Time) in the session; presumably some session stores gob-encode
// their data, which is why this registration is needed — confirm per store.
func init() {
	var zero time.Time
	gob.Register(zero)
}
// CSRF is the per-request CSRF service. It exposes the token prepared for the
// current request and checks tokens submitted by clients.
type CSRF interface {
	// Token returns the current token, typically rendered into a hidden form
	// field of an HTML template.
	Token() string
	// ValidToken reports whether t matches the token expected for the current
	// secret and user ID.
	ValidToken(t string) bool
	// Validate checks the CSRF token of the request behind ctx. The token is
	// taken from the HTTP header or, if the header is empty, from the form
	// value. A token that is found is checked with ValidToken; an invalid one is
	// answered via Error. When no token is found at all, http.StatusBadRequest
	// is sent.
	Validate(ctx flamego.Context)
	// Error writes the configured error response to w.
	Error(w http.ResponseWriter)
}
// csrf is the CSRF implementation that Csrfer maps into each request.
type csrf struct {
	header    string                      // HTTP header name the token is read from (see also Options.SetHeader)
	form      string                      // form field name the token is read from
	token     string                      // token exposed through Token
	id        string                      // identifies the current user; mixed into token generation and validation
	secret    string                      // secret combined with id to generate and validate tokens
	errorFunc func(w http.ResponseWriter) // writes the response when a supplied token is invalid
}

// Token returns the token stored on c for the current request.
func (c *csrf) Token() string { return c.token }

// ValidToken reports whether t is valid for c's secret and ID. POST is used as
// the action, matching the action Csrfer passes when generating tokens.
func (c *csrf) ValidToken(t string) bool {
	const action = http.MethodPost
	return ValidToken(t, c.secret, c.id, action)
}

// Error responds through the configured error function.
func (c *csrf) Error(w http.ResponseWriter) {
	respond := c.errorFunc
	respond(w)
}

// Validate looks for a token in the configured header first and falls back to
// the configured form field; the form is only consulted when the header is
// empty. A missing token yields 400 with a fixed message, and a present but
// invalid token is handed to Error.
func (c *csrf) Validate(ctx flamego.Context) {
	req := ctx.Request()

	token := req.Header.Get(c.header)
	if token == "" {
		token = req.FormValue(c.form)
	}

	switch {
	case token == "":
		http.Error(ctx.ResponseWriter(), "Bad Request: no CSRF token present", http.StatusBadRequest)
	case !c.ValidToken(token):
		c.Error(ctx.ResponseWriter())
	}
}
// Options configures the csrf.Csrfer middleware. Any zero-valued string field
// and a nil ErrorFunc are replaced by Csrfer's defaults.
type Options struct {
	// Secret is mixed into every generated token. When empty, Csrfer generates
	// a random 10-character string.
	Secret string
	// Header is the HTTP header the token is read from, and written to when
	// SetHeader is true. Defaults to "X-CSRF-Token".
	Header string
	// Form is the form field the token is read from. Defaults to "_csrf".
	Form string
	// SessionKey is the session key holding the user's unique ID. Defaults to
	// "userID".
	SessionKey string
	// SetHeader, when true, makes Csrfer write a newly generated token to the
	// response under Header. Defaults to false.
	SetHeader bool
	// NoOrigin, when true, stops Csrfer from generating a new token for
	// requests that carry an Origin header. Defaults to false.
	NoOrigin bool
	// ErrorFunc writes the response when ValidToken fails. Defaults to a 400
	// "Bad Request: invalid CSRF token" reply.
	ErrorFunc func(w http.ResponseWriter)
}
// src is a package-level math/rand source seeded from the wall clock when the
// package is initialized. math/rand is documented as unsuitable for
// security-sensitive work, so src must not be used to produce secrets.
var src rand.Source = rand.NewSource(time.Now().UnixNano())
// randomBytes returns n bytes chosen uniformly at random from the 62
// characters [0-9A-Za-z]. It returns an empty slice when n <= 0.
//
// Csrfer uses the result as the default Options.Secret, so the output must be
// unpredictable. For that reason the bytes come from crypto/rand; math/rand is
// documented as unsuitable for security-sensitive work. Rejection sampling on
// 6-bit values (62 of 64 accepted) keeps the distribution unbiased.
//
// It panics if the operating system's secure random source fails, because
// continuing with a guessable secret would silently disable CSRF protection.
func randomBytes(n int) []byte {
	const (
		letterBytes   = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
		letterIdxBits = 6                    // 6 bits to represent a letter index
		letterIdxMask = 1<<letterIdxBits - 1 // All 1-bits, as many as letterIdxBits
	)
	if n <= 0 {
		return []byte{}
	}

	b := make([]byte, n)
	buf := make([]byte, n)
	for i := 0; i < n; {
		if _, err := crand.Read(buf); err != nil {
			panic("csrf: reading crypto/rand: " + err.Error())
		}
		for _, r := range buf {
			// Discard indices 62 and 63 instead of wrapping them, which would
			// make the first characters of letterBytes more likely.
			if idx := int(r & letterIdxMask); idx < len(letterBytes) {
				b[i] = letterBytes[idx]
				i++
				if i == n {
					break
				}
			}
		}
	}
	return b
}
// csrfInvoker is the function type of the per-request handler built by
// Csrfer. It implements inject.FastInvoker so the injector can call it through
// Invoke.
type csrfInvoker func(flamego.Context, session.Session)

// Compile-time assertion that csrfInvoker satisfies inject.FastInvoker.
var _ inject.FastInvoker = (*csrfInvoker)(nil)

// Invoke calls the wrapped function with args[0] as the flamego.Context and
// args[1] as the session.Session. It panics if args is shorter than two or
// either element has a different type. It always returns no values and a nil
// error.
func (fn csrfInvoker) Invoke(args []interface{}) ([]reflect.Value, error) {
	ctx := args[0].(flamego.Context)
	sess := args[1].(session.Session)
	fn(ctx, sess)
	return nil, nil
}
const (
	// Fallback values applied by Csrfer when the matching Options field is
	// left empty.
	defaultHeader     = "X-CSRF-Token"
	defaultForm       = "_csrf"
	defaultSessionKey = "userID"

	// tokenExpiredAtKey is the session key under which Csrfer stores the
	// time.Time at which the current token stops being reused.
	tokenExpiredAtKey = "flamego::csrf::tokenExpiredAt"
)
// Csrfer returns a middleware handler that maps a csrf.CSRF service into every
// request. If the session already holds a reusable token, the service exposes
// that token. Otherwise a new token is generated, but only on GET requests,
// and, when Options.NoOrigin is set, only for requests without an Origin
// header.
//
// A stored token is reusable when all of the following hold:
//   - the user ID recorded with it matches the current one;
//   - its expiry time is still in the future;
//   - it is non-empty.
func Csrfer(opts ...Options) flamego.Handler {
	var opt Options
	if len(opts) != 0 {
		opt = opts[0]
	}

	// Secret is handled on its own so randomBytes only runs when needed.
	if opt.Secret == "" {
		opt.Secret = string(randomBytes(10))
	}
	fillDefault := func(field *string, fallback string) {
		if *field == "" {
			*field = fallback
		}
	}
	fillDefault(&opt.Header, defaultHeader)
	fillDefault(&opt.Form, defaultForm)
	fillDefault(&opt.SessionKey, defaultSessionKey)
	if opt.ErrorFunc == nil {
		opt.ErrorFunc = func(w http.ResponseWriter) {
			http.Error(w, "Bad Request: invalid CSRF token", http.StatusBadRequest)
		}
	}

	// Session keys private to this middleware.
	const (
		oldIDKey = "flamego::csrf::oldID"
		tokenKey = "flamego::csrf::token"
	)

	return csrfInvoker(func(ctx flamego.Context, sess session.Session) {
		x := &csrf{
			header:    opt.Header,
			form:      opt.Form,
			secret:    opt.Secret,
			errorFunc: opt.ErrorFunc,
		}
		ctx.MapTo(x, (*CSRF)(nil))

		// The user ID is part of the token; sessions without one use "0".
		x.id = "0"
		if uid := sess.Get(opt.SessionKey); uid != nil {
			x.id = fmt.Sprintf("%v", uid)
		}

		// loadStored copies the session's token into x and reports true when
		// that token can be reused. The ID can change upon user authentication,
		// so a token recorded under a different ID is never reused.
		loadStored := func() bool {
			if prevID, ok := sess.Get(oldIDKey).(string); !ok || prevID != x.id {
				return false
			}
			if expiresAt, ok := sess.Get(tokenExpiredAtKey).(time.Time); !ok || !expiresAt.After(time.Now()) {
				return false
			}
			stored, ok := sess.Get(tokenKey).(string)
			if !ok || stored == "" {
				return false
			}
			x.token = stored
			return true
		}

		// loadStored runs on every request, not only GET, so x.token is
		// populated whenever a reusable token exists.
		if loadStored() || ctx.Request().Method != http.MethodGet {
			return
		}
		if opt.NoOrigin && ctx.Request().Header.Get("Origin") != "" {
			return
		}

		x.token = GenerateToken(x.secret, x.id, http.MethodPost)
		sess.Set(oldIDKey, x.id)
		sess.Set(tokenKey, x.token)
		// Mark the token expired one minute early so it is renewed before the
		// hard timeout.
		sess.Set(tokenExpiredAtKey, time.Now().Add(timeout).Add(-time.Minute))
		if opt.SetHeader {
			ctx.ResponseWriter().Header().Set(opt.Header, x.token)
		}
	})
}
// Validate is a per-route middleware that checks the request's CSRF token. It
// delegates to the Validate method of the CSRF service that Csrfer maps into
// the request, so Csrfer must run earlier in the handler chain.
func Validate(ctx flamego.Context, x CSRF) {
	x.Validate(ctx)
}