Skip to content

Commit

Permalink
feat: add jwt decode function
Browse files Browse the repository at this point in the history
Signed-off-by: Sanskarzz <[email protected]>
  • Loading branch information
Sanskarzz committed Mar 27, 2024
1 parent d1edd80 commit 8607a44
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 0 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ require (
github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect
github.com/go-logr/logr v1.4.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/golang/protobuf v1.5.3 // indirect
github.com/gopherjs/gopherjs v1.17.2 // indirect
github.com/jtolds/gls v4.20.0+incompatible // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ=
github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
Expand Down
22 changes: 22 additions & 0 deletions pkg/functions/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package functions

import (
"fmt"
)

const (
errorPrefix = "JMESPath function '%s': "
invalidArgumentTypeError = errorPrefix + "argument #%d is not of type %s"
genericError = errorPrefix + "%s"
argOutOfBoundsError = errorPrefix + "%d argument is out of bounds (%d)"
zeroDivisionError = errorPrefix + "Zero divisor passed"
nonIntModuloError = errorPrefix + "Non-integer argument(s) passed for modulo"
typeMismatchError = errorPrefix + "Types mismatch"
nonIntRoundError = errorPrefix + "Non-integer argument(s) passed for round off"
)

func formatError(format string, function string, values ...any) error {
args := []any{function}
args = append(args, values...)
return fmt.Errorf(format, args...)
}
71 changes: 71 additions & 0 deletions pkg/functions/functions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package functions

import (
"encoding/base64"
"fmt"
"reflect"

"github.com/golang-jwt/jwt"
"github.com/jmespath-community/go-jmespath/pkg/functions"
)

func GetFunctions() []functions.FunctionEntry {
return []functions.FunctionEntry{{
Name: "jwt_decode",
Arguments: []functions.ArgSpec{
{Types: []functions.JpType{functions.JpString}},
{Types: []functions.JpType{functions.JpString}},
},
Handler: jwt_decode,
}}
}

func jwt_decode(arguments []any) (any, error) {

// Validate argument
tokenString, err := validateArg(" ", arguments, 0, reflect.String)
if err != nil {
return nil, fmt.Errorf("invalidArgumentTypeError: %w", err)
}
tokenStringVal := tokenString.String()

secretkey, err := validateArg(" ", arguments, 1, reflect.String)
if err != nil {
return nil, fmt.Errorf("invalidArgumentTypeError: %w", err)
}

// Attempt to decode the base64 encoded secret key
decodedKey, err := base64.StdEncoding.DecodeString(secretkey.String())
if err != nil {
// If decoding fails, assume the secret key is not base64 encoded
decodedKey = []byte(secretkey.String())
}

token, err := jwt.Parse(tokenStringVal, func(token *jwt.Token) (interface{}, error) {
return decodedKey, nil
})
if err != nil {
return nil, fmt.Errorf("invalid JWT token: %w", err)
}

result := map[string]any{
"header": jwt.MapClaims(token.Header),
"payload": jwt.MapClaims(token.Claims.(jwt.MapClaims)),
"sig": fmt.Sprintf("%x", token.Signature),
}
return result, nil
}

func validateArg(f string, arguments []any, index int, expectedType reflect.Kind) (reflect.Value, error) {
if index >= len(arguments) {
return reflect.Value{}, formatError(argOutOfBoundsError, f, index+1, len(arguments))
}
if arguments[index] == nil {
return reflect.Value{}, formatError(invalidArgumentTypeError, f, index+1, expectedType.String())
}
arg := reflect.ValueOf(arguments[index])
if arg.Type().Kind() != expectedType {
return reflect.Value{}, formatError(invalidArgumentTypeError, f, index+1, expectedType.String())
}
return arg, nil
}
74 changes: 74 additions & 0 deletions pkg/functions/functions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package functions

import (
"fmt"
"reflect"
"sort"
"testing"
)

func Test_jwt_decode(t *testing.T) {

token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjIyNDEwODE1MzksIm5iZiI6MTUxNDg1MTEzOSwicm9sZSI6Imd1ZXN0Iiwic3ViIjoiWVd4cFkyVT0ifQ.ja1bgvIt47393ba_WbSBm35NrUhdxM4mOVQN8iXz8lk"
secret := "c2VjcmV0"
type args struct {
arguments []any
}
tests := []struct {
name string
args args
want map[string]any
wantErr bool
}{
{
args: args{[]any{token, secret}},
want: map[string]interface{}{
"header": map[string]interface{}{
"alg": "HS256",
"typ": "JWT",
},
"payload": map[string]interface{}{
"exp": 2.241081539e+09,
"nbf": 1.514851139e+09,
"role": "guest",
"sub": "YWxpY2U=",
},
"sig": "6a61316267764974343733393362615f576253426d33354e72556864784d346d4f56514e3869587a386c6b",
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := jwt_decode(tt.args.arguments)
if (err != nil) != tt.wantErr {
t.Errorf("jwt_decode() error = %v, wantErr %v", err, tt.wantErr)
return
}
gotMap := got.(map[string]any)
wantSorted := sortMap(tt.want)
gotSorted := sortMap(gotMap)

fmt.Println("Got type:", gotSorted) // To check
fmt.Println("Want type:", wantSorted) // To check

if !reflect.DeepEqual(gotSorted, wantSorted) {
t.Errorf("jwt_decode() = %v, want %v", gotSorted, wantSorted)
}
})
}
}

func sortMap(m map[string]interface{}) map[string]interface{} {
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
sort.Strings(keys)

result := make(map[string]interface{}, len(m))
for _, k := range keys {
result[k] = m[k]
}
return result
}

0 comments on commit 8607a44

Please sign in to comment.