Skip to content

Commit

Permalink
Merge pull request #276 from xmidt-org/feature/capabilities
Browse files Browse the repository at this point in the history
Feature/capabilities
  • Loading branch information
johnabass authored Aug 13, 2024
2 parents bd2a716 + 0ab6b62 commit 53b0414
Show file tree
Hide file tree
Showing 8 changed files with 209 additions and 7 deletions.
13 changes: 12 additions & 1 deletion basculejwt/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ import (
"github.com/xmidt-org/bascule/v1"
)

// CapabilitiesKey is the JWT claims key where capabilities are expected.
const CapabilitiesKey = "capabilities"

// Claims exposes standard JWT claims from a Token.
type Claims interface {
// Audience returns the aud field of the JWT.
Expand Down Expand Up @@ -75,6 +78,14 @@ func (t token) Principal() string {
return t.jwt.Subject()
}

func (t token) Capabilities() (caps []string) {
if v, ok := t.jwt.Get(CapabilitiesKey); ok {
caps, _ = bascule.GetCapabilities(v)
}

return
}

// tokenParser is the canonical parser for bascule that deals with JWTs.
// This parser does not use the source.
type tokenParser struct {
Expand All @@ -92,7 +103,7 @@ func NewTokenParser(options ...jwt.ParseOption) (bascule.TokenParser[string], er
}

// Parse parses the value as a JWT, using the parsing options passed to NewTokenParser.
// The returned Token will implement the bascule.Attributes and Claims interfaces.
// The returned Token will implement the bascule.Attributes, bascule.Capabilities, and Claims interfaces.
func (tp *tokenParser) Parse(ctx context.Context, value string) (bascule.Token, error) {
jwtToken, err := jwt.ParseString(value, tp.options...)
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions basculejwt/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/stretchr/testify/suite"
"github.com/xmidt-org/bascule/v1"
)

type TokenTestSuite struct {
Expand Down Expand Up @@ -125,6 +126,9 @@ func (suite *TokenTestSuite) TestTokenParser() {
suite.Require().NotNil(token)

suite.Equal(suite.subject, token.Principal())
caps, ok := bascule.GetCapabilities(token)
suite.Equal(suite.capabilities, caps)
suite.True(ok)

suite.Require().Implements((*Claims)(nil), token)
claims := token.(Claims)
Expand Down
64 changes: 64 additions & 0 deletions capabilities.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC
// SPDX-License-Identifier: Apache-2.0

package bascule

// CapabilitiesAccessor is an interface that any type may choose to implement
// in order to provide access to any capabilities associated with the token.
// Capabilities do not make sense for all tokens, e.g. simple basic auth tokens.
type CapabilitiesAccessor interface {
// Capabilities returns the set of capabilities associated with this token.
Capabilities() []string
}

// GetCapabilities attempts to convert a value v into a slice of capabilities.
//
// This function provide very flexible values to be used as capabilities. This is
// particularly useful when unmarshalling values, since those values may not be strings
// or slices.
//
// The following conversions are attempted, in order:
//
// (1) If v implements CapabilitiesAccessor, then Capabilities() is returned.
//
// (2) If v is a []string, it is returned as is.
//
// (3) If v is a scalar string, a slice containing only that string is returned.
//
// (4) If v is a []any, a slice containing each element cast to a string is returned.
// If any elements are not castable to string, this function considers that to be the
// same as missing capabilities, i.e. false is returned with an empty slice.
//
// If any conversion was possible, this function returns true even if the capabilities were empty.
// If no such conversion was possible, this function returns false.
func GetCapabilities(v any) (caps []string, ok bool) {
switch vt := v.(type) {
case CapabilitiesAccessor:
caps = vt.Capabilities()
ok = true

case []string:
caps = vt
ok = true

case string:
caps = []string{vt}
ok = true

case []any:
converted := make([]string, 0, len(vt))
for _, raw := range vt {
if element, isString := raw.(string); isString {
converted = append(converted, element)
} else {
break
}
}

if ok = len(converted) == len(vt); ok {
caps = converted
}
}

return
}
97 changes: 97 additions & 0 deletions capabilities_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC
// SPDX-License-Identifier: Apache-2.0

package bascule

import (
"testing"

"github.com/stretchr/testify/suite"
)

type CapabilitiesTestSuite struct {
suite.Suite
}

func (suite *CapabilitiesTestSuite) testGetCapabilitiesNil() {
caps, ok := GetCapabilities(nil)
suite.False(ok)
suite.Empty(caps)
}

func (suite *CapabilitiesTestSuite) testGetCapabilitiesAccessor() {
suite.Run("NoCapabilities", func() {
mt := new(mockToken)
caps, ok := GetCapabilities(mt)
suite.False(ok)
suite.Empty(caps)

mt.AssertExpectations(suite.T())
})

suite.Run("EmptyCapabilities", func() {
mt := new(mockTokenWithCapabilities)
mt.ExpectCapabilities().Once()
caps, ok := GetCapabilities(mt)
suite.True(ok)
suite.Empty(caps)

mt.AssertExpectations(suite.T())
})

suite.Run("HasCapabilities", func() {
mt := new(mockTokenWithCapabilities)
mt.ExpectCapabilities("one", "two", "three").Once()
caps, ok := GetCapabilities(mt)
suite.True(ok)
suite.Equal([]string{"one", "two", "three"}, caps)

mt.AssertExpectations(suite.T())
})
}

func (suite *CapabilitiesTestSuite) testGetCapabilitiesStringSlice() {
suite.Run("Empty", func() {
caps, ok := GetCapabilities([]string{})
suite.True(ok)
suite.Empty(caps)
})

suite.Run("NonEmpty", func() {
caps, ok := GetCapabilities([]string{"one", "two", "three"})
suite.True(ok)
suite.Equal([]string{"one", "two", "three"}, caps)
})
}

func (suite *CapabilitiesTestSuite) testGetCapabilitiesString() {
caps, ok := GetCapabilities("single")
suite.True(ok)
suite.Equal([]string{"single"}, caps)
}

func (suite *CapabilitiesTestSuite) testGetCapabilitiesAnySlice() {
suite.Run("AllStrings", func() {
caps, ok := GetCapabilities([]any{"one", "two", "three"})
suite.True(ok)
suite.Equal([]string{"one", "two", "three"}, caps)
})

suite.Run("NonStrings", func() {
caps, ok := GetCapabilities([]any{"one", 2.0, 3})
suite.False(ok)
suite.Empty(caps)
})
}

func (suite *CapabilitiesTestSuite) TestGetCapabilities() {
suite.Run("Nil", suite.testGetCapabilitiesNil)
suite.Run("Accessor", suite.testGetCapabilitiesAccessor)
suite.Run("StringSlice", suite.testGetCapabilitiesStringSlice)
suite.Run("String", suite.testGetCapabilitiesString)
suite.Run("AnySlice", suite.testGetCapabilitiesAnySlice)
}

func TestCapabilities(t *testing.T) {
suite.Run(t, new(CapabilitiesTestSuite))
}
30 changes: 28 additions & 2 deletions mocks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,35 @@ import (
"github.com/stretchr/testify/mock"
)

type testToken string
type stubToken string

func (tt testToken) Principal() string { return string(tt) }
func (t stubToken) Principal() string { return string(t) }

type mockToken struct {
mock.Mock
}

func (m *mockToken) Principal() string {
return m.Called().String(0)
}

func (m *mockToken) ExpectPrincipal(v string) *mock.Call {
return m.On("Principal").Return(v)
}

type mockTokenWithCapabilities struct {
mockToken
}

func (m *mockTokenWithCapabilities) Capabilities() []string {
args := m.Called()
caps, _ := args.Get(0).([]string)
return caps
}

func (m *mockTokenWithCapabilities) ExpectCapabilities(caps ...string) *mock.Call {
return m.On("Capabilities").Return(caps)
}

type mockValidator[S any] struct {
mock.Mock
Expand Down
2 changes: 1 addition & 1 deletion testSuite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func (suite *TestSuite) testContext() context.Context {
}

func (suite *TestSuite) testToken() Token {
return testToken("test")
return stubToken("test")
}

func (suite *TestSuite) contexter(ctx context.Context) Contexter {
Expand Down
2 changes: 1 addition & 1 deletion token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type TokenParserSuite struct {
func (suite *TokenParserSuite) SetupSuite() {
suite.expectedCtx = suite.testContext()
suite.expectedSource = 123
suite.expectedToken = testToken("expected token")
suite.expectedToken = stubToken("expected token")
suite.expectedErr = errors.New("expected token parser error")
}

Expand Down
4 changes: 2 additions & 2 deletions validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ type ValidatorsTestSuite struct {
func (suite *ValidatorsTestSuite) SetupSuite() {
suite.expectedCtx = suite.testContext()
suite.expectedSource = 123
suite.inputToken = testToken("input token")
suite.outputToken = testToken("output token")
suite.inputToken = stubToken("input token")
suite.outputToken = stubToken("output token")
suite.expectedErr = errors.New("expected validator error")
}

Expand Down

0 comments on commit 53b0414

Please sign in to comment.