forked from MicahParks/keyfunc
-
Notifications
You must be signed in to change notification settings - Fork 0
/
jwks.go
397 lines (363 loc) · 11.7 KB
/
jwks.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
package keyfunc
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt"
)
var (
	// ErrJWKAlgMismatch indicates that the given JWK was found, but its "alg" parameter's value did not match that of
	// the JWT.
	ErrJWKAlgMismatch = errors.New(`the given JWK was found, but its "alg" parameter's value did not match the expected algorithm`)

	// ErrNoMatchingKey indicates that no JWK matches the token according to the SingleStore key-matching rules.
	ErrNoMatchingKey = errors.New("no key can be matched to validate the token")

	// ErrMissingAssets indicates that assets required to create a public key are missing.
	ErrMissingAssets = errors.New("required assets are missing to create a public key")
)
// JWKUse enumerates the values a JWK's "use" parameter may take.
// Reference: https://tools.ietf.org/html/rfc7517#section-4.2.
type JWKUse string

const (
	// UseEncryption is the "use" value ("enc") for a JSON Web Key meant for encryption.
	UseEncryption JWKUse = "enc"
	// UseOmitted is the "use" value for a JSON Web Key whose "use" parameter was absent or empty.
	UseOmitted JWKUse = ""
	// UseSignature is the "use" value ("sig") for a JSON Web Key meant for signatures.
	UseSignature JWKUse = "sig"
)

// ErrorHandler is a callback that receives an error.
type ErrorHandler func(err error)
// JsonWebKey represents a JSON Web Key inside a JWKS.
// The JSON tags are the JWK parameter names. UsernameFrom ("usernameFrom") and Audience ("aud") are not standard JWK
// parameters.
type JsonWebKey struct {
// Algorithm is the "alg" parameter; it may be empty, in which case key matching falls back to Type.
Algorithm string `json:"alg"`
// Curve is the "crv" parameter (curve name for EC/OKP keys).
Curve string `json:"crv"`
// Exponent is the "e" parameter of an RSA key.
Exponent string `json:"e"`
// K is the "k" parameter of a symmetric ("oct") key.
K string `json:"k"`
// ID is the "kid" parameter; keys are grouped by this value.
ID string `json:"kid"`
// Modulus is the "n" parameter of an RSA key.
Modulus string `json:"n"`
// Type is the "kty" parameter and selects how the key material is parsed.
Type string `json:"kty"`
// Use is the "use" parameter (see JWKUse).
Use string `json:"use"`
// X is the "x" coordinate/public value of an EC or OKP key.
X string `json:"x"`
// Y is the "y" coordinate of an EC key.
Y string `json:"y"`
// UsernameFrom is a non-standard member; it is not read in this section of the file.
UsernameFrom string `json:"usernameFrom"`
// Audience is the non-standard "aud" member. When decoded with encoding/json, a JSON string becomes a string and a
// JSON array becomes a []interface{}; it is nil when absent.
Audience interface{} `json:"aud"`
}
// ParsedJWK represents a JSON Web Key parsed with fields as the correct Go types.
type ParsedJWK struct {
// algorithm is the JWK "alg" value; when empty, kty is used to match the token's algorithm (see filterKeys).
algorithm string
// kty is the JWK "kty" value.
kty string
// Public is the key material produced from the JWK by NewJSON (via ECDSA, EdDSA, Oct or RSA depending on kty).
Public interface{}
// use is the JWK "use" value, checked against the JWKS whitelist (see canUseKey).
use JWKUse
// Jwk is the JSON Web Key this entry was parsed from.
Jwk *JsonWebKey
// kid is the JWK "kid" value.
kid string
// audience holds the trimmed entries of the JWK's non-standard "aud" member; when empty, audience checking is
// skipped for this key (see filterKeys).
audience []string
}
// JWKS represents a JSON Web Key Set (JWK Set).
//
// In this section of the file, keys and raw are only read while mux is read-locked (see KIDs, Len, RawJWKS and
// GetMatchingKeys). Fields without a comment are not used in this section.
type JWKS struct {
// jwkUseWhitelist is the set of allowed JWK "use" values; an empty set allows every key (see canUseKey).
jwkUseWhitelist map[JWKUse]struct{}
// cancel, when set, is called by EndBackground to stop the background refresh.
cancel context.CancelFunc
client *http.Client
// ctx is the parent context for refresh requests made by GetMatchingKeysWithRefresh; it may be nil when no
// background refresh was configured (e.g. a JWKS built by NewJSON).
ctx context.Context
// raw holds the raw JWKS bytes, returned as a copy by RawJWKS.
raw []byte
givenKeys map[string]GivenKey
givenKIDOverride bool
jwksURL string
// keys maps each "kid" to every parsed key that shares it.
keys map[string][]ParsedJWK
// mux guards keys and raw.
mux sync.RWMutex
refreshErrorHandler ErrorHandler
refreshInterval time.Duration
refreshRateLimit time.Duration
// refreshRequests carries refresh requests from GetMatchingKeysWithRefresh; each request is a cancel func that the
// receiver is expected to call once the refresh has finished.
refreshRequests chan context.CancelFunc
refreshTimeout time.Duration
initAsync bool
// refreshUnknownKID makes GetMatchingKeysWithRefresh request a refresh when no key matches a token.
refreshUnknownKID bool
requestFactory func(ctx context.Context, url string) (*http.Request, error)
responseExtractor func(ctx context.Context, resp *http.Response) (json.RawMessage, error)
}
// rawJWKS represents a JWKS in JSON format: the top-level object whose "keys" member holds the JWKs.
// When decoded with encoding/json, an entry is nil if the corresponding array element is JSON null.
type rawJWKS struct {
Keys []*JsonWebKey `json:"keys"`
}
// NewJSON creates a new JWKS from a raw JSON message.
//
// Entries that are JSON null, have an unrecognized "kty", or whose key material cannot be parsed are skipped
// silently, so one bad key does not make the whole set unusable. Keys are grouped by "kid"; several keys may share
// the same "kid". The non-standard "aud" member of a JWK may be a comma-separated string or a JSON array of strings;
// every entry is whitespace-trimmed.
//
// An error is returned only when jwksBytes is not valid JSON for a JWK Set.
func NewJSON(jwksBytes json.RawMessage) (jwks *JWKS, err error) {
	var rawKS rawJWKS
	if err = json.Unmarshal(jwksBytes, &rawKS); err != nil {
		return nil, err
	}

	jwks = &JWKS{
		keys: make(map[string][]ParsedJWK, len(rawKS.Keys)),
	}

	// Iterate through the keys in the raw JWKS. Add them to the JWKS.
	for _, key := range rawKS.Keys {
		if key == nil {
			// A JSON null inside "keys" decodes to a nil pointer; skip it instead of panicking.
			continue
		}

		var keyInter interface{}
		switch key.Type {
		case ktyEC:
			keyInter, err = key.ECDSA()
		case ktyOKP:
			keyInter, err = key.EdDSA()
		case ktyOct:
			keyInter, err = key.Oct()
		case ktyRSA:
			keyInter, err = key.RSA()
		default:
			// Ignore unknown key types silently.
			continue
		}
		if err != nil {
			// Best effort: a malformed key must not prevent the remaining keys from being used.
			continue
		}

		jwks.keys[key.ID] = append(jwks.keys[key.ID], ParsedJWK{
			algorithm: key.Algorithm,
			kty:       key.Type,
			use:       JWKUse(key.Use),
			Public:    keyInter,
			Jwk:       key,
			kid:       key.ID,
			audience:  parseJWKAudience(key.Audience),
		})
	}

	// Per-key parse errors were deliberately skipped above; they are not reported to the caller.
	return jwks, nil
}

// parseJWKAudience normalizes a JWK's non-standard "aud" member into a slice of trimmed strings.
// A string is split on commas; an array contributes its string elements. Any other value yields an empty slice.
func parseJWKAudience(aud interface{}) []string {
	var audience []string
	switch v := aud.(type) {
	case string:
		audience = strings.Split(v, ",")
	case []interface{}:
		// encoding/json decodes a JSON array into []interface{} when the target is interface{}.
		audience = make([]string, 0, len(v))
		for _, item := range v {
			if s, ok := item.(string); ok {
				audience = append(audience, s)
			}
		}
	case []string:
		// Not produced by encoding/json, but supported for callers that fill JsonWebKey.Audience directly.
		// Copy it so trimming below does not modify the caller's slice.
		audience = append(make([]string, 0, len(v)), v...)
	default:
		return make([]string, 0)
	}
	for i := range audience {
		audience[i] = strings.TrimSpace(audience[i])
	}
	return audience
}
// EndBackground stops the background goroutine that refreshes the JWKS keys. It only takes effect once, and only when
// the JWKS was set up with such a goroutine; otherwise it does nothing.
func (j *JWKS) EndBackground() {
	if stop := j.cancel; stop != nil {
		stop()
	}
}
// KIDs lists every distinct key ID (`kid`) in the JWKS, in no particular order.
// The JWKS is read-locked while the IDs are collected.
func (j *JWKS) KIDs() (kids []string) {
	j.mux.RLock()
	defer j.mux.RUnlock()

	kids = make([]string, 0, len(j.keys))
	for id := range j.keys {
		kids = append(kids, id)
	}
	return kids
}
// Len reports how many parsed keys the JWKS holds. Keys sharing one `kid` are each counted.
func (j *JWKS) Len() int {
	j.mux.RLock()
	defer j.mux.RUnlock()

	var total int
	for kid := range j.keys {
		total += len(j.keys[kid])
	}
	return total
}
// RawJWKS returns a copy of the raw JWKS received from the given JWKS URL. The result is a fresh slice, so callers
// may modify it without affecting the JWKS.
func (j *JWKS) RawJWKS() []byte {
	j.mux.RLock()
	defer j.mux.RUnlock()
	return append(make([]byte, 0, len(j.raw)), j.raw...)
}
// ReadOnlyKeys returns a copy of the mapping of key IDs (`kid`) to cryptographic keys.
// When several keys share a `kid`, only the first one is included.
// Currently this function is used for test purposes only.
func (j *JWKS) ReadOnlyKeys() map[string]interface{} {
	// Only reads are performed, so a shared lock suffices; defer guarantees release even on panic.
	j.mux.RLock()
	defer j.mux.RUnlock()

	keys := make(map[string]interface{}, len(j.keys))
	for kid, parsedKeys := range j.keys {
		if len(parsedKeys) == 0 {
			// Defensive: an empty group has no key to expose.
			continue
		}
		// TODO: generalize this function to account for multiple keys with a given kid
		keys[kid] = parsedKeys[0].Public
	}
	return keys
}
// canUseKey reports whether key's "use" value is permitted by the JWKS "use" whitelist.
// The whitelist may be empty, e.g. for a JWKS from keyfunc.NewJSON() or when the JWKUseNoWhitelist option was set;
// an empty whitelist places no restriction on "use".
func (j *JWKS) canUseKey(key ParsedJWK) bool {
	if len(j.jwkUseWhitelist) == 0 {
		return true
	}
	_, allowed := j.jwkUseWhitelist[key.use]
	return allowed
}
// checkSlicesIntersect reports whether at least one string occurs in both slice1 and slice2.
func checkSlicesIntersect(slice1 []string, slice2 []string) bool {
	if len(slice1) == 0 || len(slice2) == 0 {
		return false
	}
	present := make(map[string]struct{}, len(slice2))
	for _, s := range slice2 {
		present[s] = struct{}{}
	}
	for _, s := range slice1 {
		if _, found := present[s]; found {
			return true
		}
	}
	return false
}
// GetTypeForAlg returns the JWK Key Type (kty) that corresponds to a given JWS `alg` value, or "" when the
// algorithm is not recognized.
//
// kty: https://www.rfc-editor.org/rfc/rfc7518#section-7.4.2
// alg: https://www.rfc-editor.org/rfc/rfc7518#section-7.1.2
// EdDSA/OKP: https://www.rfc-editor.org/rfc/rfc8037
func GetTypeForAlg(alg string) string {
	switch alg {
	case "RS256", "RS384", "RS512", "PS256", "PS384", "PS512":
		// RSASSA-PKCS1-v1_5 (RS*) and RSASSA-PSS (PS*) both verify with RSA keys.
		return "RSA"
	case "ES256", "ES384", "ES512":
		return "EC"
	case "HS256", "HS384", "HS512":
		return "oct"
	case "EdDSA":
		return "OKP"
	default:
		return ""
	}
}
// filterKeys returns pointers to the entries of parsedKeys that may validate token, which was signed with alg.
// A key is kept when all of the following hold:
//   - its "alg" equals alg, or it has no "alg" and its "kty" is the key type for alg (see GetTypeForAlg);
//   - its "use" is allowed by the JWKS whitelist (see canUseKey);
//   - it declares no audience, or one of its audiences equals one of the token's "aud" values.
//
// The returned pointers alias elements of parsedKeys.
func (j *JWKS) filterKeys(alg string, token *jwt.Token, parsedKeys []ParsedJWK) []*ParsedJWK {
	var result []*ParsedJWK

	// The token audience is only needed for keys that declare one; compute it lazily, at most once.
	var (
		tokenAud      []string
		tokenAudKnown bool
	)

	for idx := range parsedKeys {
		key := &parsedKeys[idx]

		algMatches := key.algorithm == alg || (key.algorithm == "" && GetTypeForAlg(alg) == key.kty)
		if !algMatches || !j.canUseKey(*key) {
			continue
		}

		// https://docs.singlestore.com/db/v7.8/en/security/authentication/authenticate-via-jwt.html#validate-jwts-with-jwks
		// If the matching JWK includes an aud (Audience) field which does not match the aud field in the JWT, then
		// the authentication request is rejected. The aud field can be a string or an array of strings. If any aud
		// string of the JWT matches any aud string of the JWK, it is considered a match. If the JWK does not define
		// an audience, audience checking is skipped. Note that aud is not a standard field in JWK.
		if len(key.audience) > 0 {
			if !tokenAudKnown {
				tokenAud = tokenAudience(token)
				tokenAudKnown = true
			}
			if !checkSlicesIntersect(tokenAud, key.audience) {
				continue
			}
		}

		result = append(result, key)
	}
	return result
}

// tokenAudience extracts the trimmed "aud" values of token. A string claim is treated as a comma-separated list and
// an array claim contributes its string elements. It returns nil when the claims are not jwt.MapClaims or carry no
// usable "aud".
func tokenAudience(token *jwt.Token) []string {
	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return nil
	}

	var aud []string
	switch v := claims["aud"].(type) {
	case string:
		aud = strings.Split(v, ",")
	case []interface{}:
		// Claims decoded by encoding/json represent a JSON array as []interface{}.
		aud = make([]string, 0, len(v))
		for _, item := range v {
			if s, ok := item.(string); ok {
				aud = append(aud, s)
			}
		}
	case []string:
		// Copy so that trimming below does not modify the token's claims.
		aud = append(make([]string, 0, len(v)), v...)
	}
	for i := range aud {
		aud[i] = strings.TrimSpace(aud[i])
	}
	return aud
}
// GetMatchingKeys implements the logic described in
// https://docs.singlestore.com/db/v7.8/en/security/authentication/authenticate-via-jwt.html
// A read lock on `j.mux` is held while the JWKS is read.
//
// JWTs are matched with JSON Web Keys (JWKs) for validation as follows:
//  1. If the JWT has a kid (Key ID) field, the JWKs with matching kid fields are validated.
//  2. If the JWT has a kid field that doesn't match any JWK or jwt_config key, the authentication request is rejected.
//     See Validate JWTs with the jwt-config for more information.
//  3. If the JWT has an iss (Issuer) field (instead of a kid field) that matches the kid in one or more JWKs, the JWKs
//     with matching kid fields are validated.
//  4. If the JWT does not have a kid field and the iss field does not match the kid field in any JWK, then validation
//     is attempted with all the JWKs with a matching alg (Algorithm) field. If the alg field is not specified, the kty
//     (Key Type) field is used instead.
//
// An empty result with a nil error means no key can validate the token. An error is returned when the header lacks a
// usable `alg`, when `kid` or `iss` is not a string, or when the claims are not jwt.MapClaims (checked only when the
// token has no `kid`).
func (j *JWKS) GetMatchingKeys(token *jwt.Token) ([]*ParsedJWK, error) {
	var result []*ParsedJWK

	// alg must be present in the JWT header.
	algInter, ok := token.Header["alg"]
	if !ok {
		return result, errors.New("could not validate a JWT without `alg` in header")
	}
	alg, ok := algInter.(string)
	if !ok {
		return result, errors.New("could not convert `alg` in JWT header to string")
	}
	if alg == "" {
		// An empty alg would otherwise match every key that also lacks "alg".
		return result, errors.New("could not validate a JWT with an empty `alg` in header")
	}

	var kid string
	if kidInter, ok := token.Header["kid"]; ok {
		if kid, ok = kidInter.(string); !ok {
			return result, errors.New("could not convert `kid` in JWT header to string")
		}
	}

	j.mux.RLock()
	defer j.mux.RUnlock()

	if kid != "" {
		// Rules 1 and 2: when "kid" is present in the JWT, only keys with the same kid are considered.
		if parsedKeys, ok := j.keys[kid]; ok {
			result = j.filterKeys(alg, token, parsedKeys)
		}
		return result, nil
	}

	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		// Do not embed token.Raw: error messages are often logged, and the raw token is a bearer credential.
		return result, fmt.Errorf("cannot get claims from the token: unsupported claims type %T", token.Claims)
	}

	var iss string
	if issInter, ok := claims["iss"]; ok {
		if iss, ok = issInter.(string); !ok {
			return result, errors.New("could not convert `iss` in JWT claims to string")
		}
	}

	// Rule 3: an "iss" that names a kid selects the keys with that kid.
	if iss != "" {
		if parsedKeys, ok := j.keys[iss]; ok {
			result = j.filterKeys(alg, token, parsedKeys)
		}
	}

	// Rule 4: no "kid" and no match through "iss", so use "alg" only.
	// NOTE(review): this fallback also runs when "iss" names a kid but none of those keys passes filterKeys; rule 3
	// may instead expect rejection in that case. Behavior kept as-is; confirm against SingleStore.
	if len(result) == 0 {
		for _, parsedKeys := range j.keys {
			result = append(result, j.filterKeys(alg, token, parsedKeys)...)
		}
	}
	return result, nil
}
// GetMatchingKeysWithRefresh gets the keys according to SingleStore logic (see GetMatchingKeys). If no key matches
// and `j.refreshUnknownKID` is true, it requests a JWKS refresh. It then waits for the refresh to finish, or for the
// JWKS context to end, and matches again.
//
// Errors from GetMatchingKeys are discarded; an empty result means no key can validate the token. No refresh is
// attempted when the JWKS has no context (e.g. it was built with NewJSON). The empty result is returned immediately
// if the refresh request channel cannot accept the request (full or nil).
func (j *JWKS) GetMatchingKeysWithRefresh(token *jwt.Token) []*ParsedJWK {
	// Errors are intentionally ignored: callers only need to know whether any key matched.
	matchingKeys, _ := j.GetMatchingKeys(token)
	if len(matchingKeys) > 0 || !j.refreshUnknownKID || j.ctx == nil {
		// context.WithCancel panics on a nil parent, so a JWKS without a context cannot be refreshed here.
		return matchingKeys
	}

	ctx, cancel := context.WithCancel(j.ctx)
	// Release the context on every return path. The refresh receiver may also call cancel; extra calls are no-ops.
	defer cancel()

	// Request a refresh without blocking.
	select {
	case <-j.ctx.Done():
		return matchingKeys
	case j.refreshRequests <- cancel:
	default:
		// If the j.refreshRequests channel is full, just return matchingKeys.
		return matchingKeys
	}

	// Wait for the JWKS refresh to finish. The receiver is expected to call cancel when done; ctx also ends with j.ctx.
	<-ctx.Done()
	matchingKeys, _ = j.GetMatchingKeys(token)
	return matchingKeys
}