-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Sanskarzz <[email protected]>
- Loading branch information
Showing
5 changed files
with
170 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
package functions | ||
|
||
import ( | ||
"fmt" | ||
) | ||
|
||
const ( | ||
errorPrefix = "JMESPath function '%s': " | ||
invalidArgumentTypeError = errorPrefix + "argument #%d is not of type %s" | ||
genericError = errorPrefix + "%s" | ||
argOutOfBoundsError = errorPrefix + "%d argument is out of bounds (%d)" | ||
zeroDivisionError = errorPrefix + "Zero divisor passed" | ||
nonIntModuloError = errorPrefix + "Non-integer argument(s) passed for modulo" | ||
typeMismatchError = errorPrefix + "Types mismatch" | ||
nonIntRoundError = errorPrefix + "Non-integer argument(s) passed for round off" | ||
) | ||
|
||
func formatError(format string, function string, values ...any) error { | ||
args := []any{function} | ||
args = append(args, values...) | ||
return fmt.Errorf(format, args...) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
package functions | ||
|
||
import ( | ||
"encoding/base64" | ||
"fmt" | ||
"reflect" | ||
|
||
"github.com/golang-jwt/jwt" | ||
"github.com/jmespath-community/go-jmespath/pkg/functions" | ||
) | ||
|
||
func GetFunctions() []functions.FunctionEntry { | ||
return []functions.FunctionEntry{{ | ||
Name: "jwt_decode", | ||
Arguments: []functions.ArgSpec{ | ||
{Types: []functions.JpType{functions.JpString}}, | ||
{Types: []functions.JpType{functions.JpString}}, | ||
}, | ||
Handler: jwt_decode, | ||
}} | ||
} | ||
|
||
func jwt_decode(arguments []any) (any, error) { | ||
|
||
// Validate argument | ||
tokenString, err := validateArg(" ", arguments, 0, reflect.String) | ||
if err != nil { | ||
return nil, fmt.Errorf("invalidArgumentTypeError: %w", err) | ||
} | ||
tokenStringVal := tokenString.String() | ||
|
||
secretkey, err := validateArg(" ", arguments, 1, reflect.String) | ||
if err != nil { | ||
return nil, fmt.Errorf("invalidArgumentTypeError: %w", err) | ||
} | ||
|
||
// Attempt to decode the base64 encoded secret key | ||
decodedKey, err := base64.StdEncoding.DecodeString(secretkey.String()) | ||
if err != nil { | ||
// If decoding fails, assume the secret key is not base64 encoded | ||
decodedKey = []byte(secretkey.String()) | ||
} | ||
|
||
token, err := jwt.Parse(tokenStringVal, func(token *jwt.Token) (interface{}, error) { | ||
return decodedKey, nil | ||
}) | ||
if err != nil { | ||
return nil, fmt.Errorf("invalid JWT token: %w", err) | ||
} | ||
|
||
result := map[string]any{ | ||
"header": jwt.MapClaims(token.Header), | ||
"payload": jwt.MapClaims(token.Claims.(jwt.MapClaims)), | ||
"sig": fmt.Sprintf("%x", token.Signature), | ||
} | ||
return result, nil | ||
} | ||
|
||
func validateArg(f string, arguments []any, index int, expectedType reflect.Kind) (reflect.Value, error) { | ||
if index >= len(arguments) { | ||
return reflect.Value{}, formatError(argOutOfBoundsError, f, index+1, len(arguments)) | ||
} | ||
if arguments[index] == nil { | ||
return reflect.Value{}, formatError(invalidArgumentTypeError, f, index+1, expectedType.String()) | ||
} | ||
arg := reflect.ValueOf(arguments[index]) | ||
if arg.Type().Kind() != expectedType { | ||
return reflect.Value{}, formatError(invalidArgumentTypeError, f, index+1, expectedType.String()) | ||
} | ||
return arg, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
package functions | ||
|
||
import ( | ||
"fmt" | ||
"reflect" | ||
"sort" | ||
"testing" | ||
) | ||
|
||
func Test_jwt_decode(t *testing.T) { | ||
|
||
token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjIyNDEwODE1MzksIm5iZiI6MTUxNDg1MTEzOSwicm9sZSI6Imd1ZXN0Iiwic3ViIjoiWVd4cFkyVT0ifQ.ja1bgvIt47393ba_WbSBm35NrUhdxM4mOVQN8iXz8lk" | ||
secret := "c2VjcmV0" | ||
type args struct { | ||
arguments []any | ||
} | ||
tests := []struct { | ||
name string | ||
args args | ||
want map[string]any | ||
wantErr bool | ||
}{ | ||
{ | ||
args: args{[]any{token, secret}}, | ||
want: map[string]interface{}{ | ||
"header": map[string]interface{}{ | ||
"alg": "HS256", | ||
"typ": "JWT", | ||
}, | ||
"payload": map[string]interface{}{ | ||
"exp": 2.241081539e+09, | ||
"nbf": 1.514851139e+09, | ||
"role": "guest", | ||
"sub": "YWxpY2U=", | ||
}, | ||
"sig": "6a61316267764974343733393362615f576253426d33354e72556864784d346d4f56514e3869587a386c6b", | ||
}, | ||
wantErr: false, | ||
}, | ||
} | ||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
got, err := jwt_decode(tt.args.arguments) | ||
if (err != nil) != tt.wantErr { | ||
t.Errorf("jwt_decode() error = %v, wantErr %v", err, tt.wantErr) | ||
return | ||
} | ||
gotMap := got.(map[string]any) | ||
wantSorted := sortMap(tt.want) | ||
gotSorted := sortMap(gotMap) | ||
|
||
fmt.Println("Got type:", gotSorted) // To check | ||
fmt.Println("Want type:", wantSorted) // To check | ||
|
||
if !reflect.DeepEqual(gotSorted, wantSorted) { | ||
t.Errorf("jwt_decode() = %v, want %v", gotSorted, wantSorted) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func sortMap(m map[string]interface{}) map[string]interface{} { | ||
keys := make([]string, 0, len(m)) | ||
for k := range m { | ||
keys = append(keys, k) | ||
} | ||
sort.Strings(keys) | ||
|
||
result := make(map[string]interface{}, len(m)) | ||
for _, k := range keys { | ||
result[k] = m[k] | ||
} | ||
return result | ||
} |