Skip to content

Commit

Permalink
feat: allow claims to be used for response headers
Browse files Browse the repository at this point in the history
- allow token claims to be used for response headers
- add optional config to determine which claims will be added as what header
- update tests
  • Loading branch information
denopink committed Nov 20, 2024
1 parent bf64eaa commit 51bda72
Show file tree
Hide file tree
Showing 12 changed files with 215 additions and 44 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ report.json
# for VSCode
.vscode/
.dev/
__debug_bin*

# for releases
.ignore/
Expand Down
6 changes: 5 additions & 1 deletion token/claimBuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ var (
ErrMissingKey = errors.New("A key is required for all claims and metadata values")
)

const (
jwtExpName = "exp"
)

// ClaimBuilder is a strategy for building token claims, given a token Request
type ClaimBuilder interface {
AddClaims(context.Context, *Request, map[string]interface{}) error
Expand Down Expand Up @@ -88,7 +92,7 @@ func (tc *timeClaimBuilder) AddClaims(_ context.Context, r *Request, target map[
target["iat"] = now.Unix()

if tc.duration > 0 {
target["exp"] = now.Add(tc.duration).Unix()
target[jwtExpName] = now.Add(tc.duration).Unix()
}

if !tc.disableNotBefore {
Expand Down
68 changes: 62 additions & 6 deletions token/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,81 @@ package token

import (
"context"
"errors"
"net/http"
"time"

"github.com/go-kit/kit/endpoint"
)

const (
JWTExpHeader = "Expires"
)

var (
ErrJWTExpType = errors.New("expected jwt to be of type time.Time")
)

// NewIssueEndpoint returns a go-kit endpoint for a token factory's NewToken method
func NewIssueEndpoint(f Factory) endpoint.Endpoint {
func NewIssueEndpoint(f Factory, headerClaims map[string]string) endpoint.Endpoint {
return func(ctx context.Context, v interface{}) (interface{}, error) {
return f.NewToken(ctx, v.(*Request))
req := v.(*Request)
resp := Response{
Claims: make(map[string]interface{}, len(req.Claims)),
HeaderClaims: headerClaims,
}
token, err := f.NewToken(ctx, req, resp.Claims)
if err != nil {
return Response{}, err
}

resp.Body = []byte(token)

return resp, err
}
}

// NewClaimsEndpoint returns a go-kit endpoint that returns just the claims
func NewClaimsEndpoint(cb ClaimBuilder) endpoint.Endpoint {
func NewClaimsEndpoint(cb ClaimBuilder, headerClaims map[string]string) endpoint.Endpoint {
return func(ctx context.Context, v interface{}) (interface{}, error) {
merged := make(map[string]interface{})
if err := cb.AddClaims(ctx, v.(*Request), merged); err != nil {
resp := Response{
Claims: make(map[string]interface{}),
HeaderClaims: headerClaims,
}
if err := cb.AddClaims(ctx, v.(*Request), resp.Claims); err != nil {
return nil, err
}

return merged, nil
return resp.Claims, nil
}
}

type Response struct {
// Claims is the set of token claims.
Claims map[string]interface{}
// HeaderClaims is a map of claims-to-headers, where each claim will be attempted to be added as response header.
HeaderClaims map[string]string
// Body is the response body used by the EncodeIssueResponse.
Body []byte
}

// Headers creates and returns a set of http headers based on HeaderClaims.
// Any failures to add claims are silent and does not affect the response.
func (resp Response) Headers() http.Header {
headers := http.Header{}
for claimKey, headerName := range resp.HeaderClaims {
c, ok := resp.Claims[claimKey]
if !ok {
continue

Check warning on line 72 in token/endpoint.go

View check run for this annotation

Codecov / codecov/patch

token/endpoint.go#L72

Added line #L72 was not covered by tests
}

switch v := c.(type) {
case time.Time:
headers.Add(headerName, v.Format(http.TimeFormat))

Check warning on line 77 in token/endpoint.go

View check run for this annotation

Codecov / codecov/patch

token/endpoint.go#L76-L77

Added lines #L76 - L77 were not covered by tests
case string:
headers.Add(headerName, v)
}
}

return headers
}
38 changes: 23 additions & 15 deletions token/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,22 @@ func testNewIssueEndpointSuccess(t *testing.T) {
assert = assert.New(t)
require = require.New(t)

factory = new(mockFactory)
request = NewRequest()
endpoint = NewIssueEndpoint(factory)
factory = new(mockFactory)
request = NewRequest()
expectedResp = Response{
Claims: make(map[string]interface{}, len(request.Claims)),
HeaderClaims: map[string]string{},
Body: []byte("test"),
}
endpoint = NewIssueEndpoint(factory, map[string]string{})
)

require.NotNil(endpoint)
factory.ExpectNewToken(context.Background(), request).Once().Return("test", error(nil))
token, err := endpoint(context.Background(), request)
assert.Equal("test", token)
factory.ExpectNewToken(context.Background(), request, map[string]interface{}{}).Once().Return("test", error(nil))
value, err := endpoint(context.Background(), request)
resp, ok := value.(Response)
require.True(ok)
assert.Equal(expectedResp, resp)
assert.NoError(err)

factory.AssertExpectations(t)
Expand All @@ -36,16 +43,17 @@ func testNewIssueEndpointFailure(t *testing.T) {
assert = assert.New(t)
require = require.New(t)

factory = new(mockFactory)
expectedErr = errors.New("expected")
request = NewRequest()
endpoint = NewIssueEndpoint(factory)
factory = new(mockFactory)
expectedErr = errors.New("expected")
request = NewRequest()
expectedResp = Response{}
endpoint = NewIssueEndpoint(factory, map[string]string{})
)

require.NotNil(endpoint)
factory.ExpectNewToken(context.Background(), request).Once().Return("", expectedErr)
token, actualErr := endpoint(context.Background(), request)
assert.Equal("", token)
factory.ExpectNewToken(context.Background(), request, map[string]interface{}{}).Once().Return("", expectedErr)
resp, actualErr := endpoint(context.Background(), request)
assert.Equal(expectedResp, resp)
assert.Equal(expectedErr, actualErr)

factory.AssertExpectations(t)
Expand All @@ -64,7 +72,7 @@ func testNewClaimsEndpointSuccess(t *testing.T) {
builder = new(mockClaimBuilder)
expectedClaims = map[string]interface{}{"key": "value"}
request = NewRequest()
endpoint = NewClaimsEndpoint(builder)
endpoint = NewClaimsEndpoint(builder, map[string]string{})
)

require.NotNil(endpoint)
Expand All @@ -88,7 +96,7 @@ func testNewClaimsEndpointFailure(t *testing.T) {
builder = new(mockClaimBuilder)
expectedErr = errors.New("expected")
request = NewRequest()
endpoint = NewClaimsEndpoint(builder)
endpoint = NewClaimsEndpoint(builder, map[string]string{})
)

require.NotNil(endpoint)
Expand Down
15 changes: 10 additions & 5 deletions token/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func NewRequest() *Request {
// Factory is a creation strategy for signed JWT tokens
type Factory interface {
// NewToken uses a Request to produce a signed JWT token
NewToken(context.Context, *Request) (string, error)
NewToken(context.Context, *Request, map[string]interface{}) (string, error)
}

type factory struct {
Expand All @@ -62,13 +62,18 @@ type factory struct {
pair atomic.Value
}

func (f *factory) NewToken(ctx context.Context, r *Request) (string, error) {
merged := make(map[string]interface{}, len(r.Claims))
if err := f.claimBuilder.AddClaims(ctx, r, merged); err != nil {
// NewToken returns a token based on a given request.
func (f *factory) NewToken(ctx context.Context, r *Request, claims map[string]interface{}) (string, error) {
// claims will are non nil when responses need to access them.
if claims == nil {
claims = make(map[string]interface{}, len(r.Claims))
}

if err := f.claimBuilder.AddClaims(ctx, r, claims); err != nil {
return "", err
}

token := jwt.NewWithClaims(f.method, jwt.MapClaims(merged))
token := jwt.NewWithClaims(f.method, jwt.MapClaims(claims))
pair := f.pair.Load().(key.Pair)
token.Header["kid"] = pair.KID()
return token.SignedString(pair.Sign())
Expand Down
40 changes: 37 additions & 3 deletions token/factory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/xmidt-org/themis/key"
"github.com/xmidt-org/themis/random"
"github.com/xmidt-org/themis/random/randomtest"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -48,7 +49,7 @@ func testNewFactoryInvalidKeyType(t *testing.T) {
assert.Error(err)
}

func testNewFactorySuccess(t *testing.T) {
func testNewFactoryWithoutReponseClaimsSuccess(t *testing.T) {
var (
assert = assert.New(t)
require = require.New(t)
Expand All @@ -72,13 +73,46 @@ func testNewFactorySuccess(t *testing.T) {
require.NoError(err)
require.NotNil(factory)

token, err := factory.NewToken(context.Background(), new(Request))
token, err := factory.NewToken(context.Background(), new(Request), nil)
require.NoError(err)
assert.True(len(token) > 0)
}

func testNewFactoryWithReponseClaimsSuccess(t *testing.T) {
var (
assert = assert.New(t)
require = require.New(t)
registry = key.NewRegistry(rand.Reader)
noncer = new(randomtest.Noncer)
expectedClaims = map[string]interface{}{"jti": "deadbeef"}
cb = ClaimBuilders{
nonceClaimBuilder{n: noncer},
}
)

noncer.ExpectNonce().Return("deadbeef", error(nil)).Once()
factory, err := NewFactory(Options{
Alg: "RS256",
Key: key.Descriptor{
Kid: "test",
Bits: 512,
},
Nonce: true,
}, cb, registry)

require.NoError(err)
require.NotNil(factory)

resp := Response{Claims: make(map[string]interface{})}
token, err := factory.NewToken(context.Background(), new(Request), resp.Claims)
require.NoError(err)
assert.True(len(token) > 0)
assert.Equal(expectedClaims, resp.Claims)
}

func TestNewFactory(t *testing.T) {
t.Run("InvalidAlg", testNewFactoryInvalidAlg)
t.Run("InvalidKeyType", testNewFactoryInvalidKeyType)
t.Run("Success", testNewFactorySuccess)
t.Run("WithoutReponseClaims Success", testNewFactoryWithoutReponseClaimsSuccess)
t.Run("WithReponseClaims Success", testNewFactoryWithReponseClaimsSuccess)
}
57 changes: 54 additions & 3 deletions token/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,70 @@ import (
"github.com/stretchr/testify/require"
)

func TestNewIssueHandler(t *testing.T) {
func TestNewIssueHandlerWithoutClaimHeaders(t *testing.T) {
var (
assert = assert.New(t)
require = require.New(t)

endpoint = endpoint.Endpoint(func(_ context.Context, v interface{}) (interface{}, error) {
var output bytes.Buffer
var (
resp = Response{}
output bytes.Buffer
)
output.WriteString("endpoint=run")
for key, value := range v.(*Request).Claims {
output.WriteRune(',')
fmt.Fprintf(&output, "%s=%s", key, value)
}

return output.String(), nil
resp.Body = output.Bytes()

return resp, nil
})

builders = RequestBuilders{
RequestBuilderFunc(func(original *http.Request, r *Request) error {
r.Claims["claim"] = original.Header.Get("Claim")
return nil
}),
}

handler = NewIssueHandler(endpoint, builders)
response = httptest.NewRecorder()
request = httptest.NewRequest("POST", "/", nil)
)

require.NotNil(handler)
request.Header.Set("Claim", "fromHeader")
handler.ServeHTTP(response, request)
assert.Equal("application/jose", response.Header().Get("Content-Type"))
assert.Empty(response.Header().Get("claim"))
assert.Equal("endpoint=run,claim=fromHeader", response.Body.String())
}

func TestNewIssueHandlerWithClaimHeaders(t *testing.T) {
var (
assert = assert.New(t)
require = require.New(t)

endpoint = endpoint.Endpoint(func(_ context.Context, v interface{}) (interface{}, error) {
var (
resp = Response{
Claims: make(map[string]interface{}),
HeaderClaims: map[string]string{"claim": "HeaderClaim"},
}
output bytes.Buffer
)
output.WriteString("endpoint=run")
for key, value := range v.(*Request).Claims {
output.WriteRune(',')
fmt.Fprintf(&output, "%s=%s", key, value)
resp.Claims[key] = value
}

resp.Body = output.Bytes()

return resp, nil
})

builders = RequestBuilders{
Expand All @@ -47,6 +97,7 @@ func TestNewIssueHandler(t *testing.T) {
request.Header.Set("Claim", "fromHeader")
handler.ServeHTTP(response, request)
assert.Equal("application/jose", response.Header().Get("Content-Type"))
assert.Equal("fromHeader", response.Header().Get("HeaderClaim"))
assert.Equal("endpoint=run,claim=fromHeader", response.Body.String())
}

Expand Down
8 changes: 4 additions & 4 deletions token/mocks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ type mockFactory struct {
mock.Mock
}

func (m *mockFactory) NewToken(ctx context.Context, r *Request) (string, error) {
arguments := m.Called(ctx, r)
func (m *mockFactory) NewToken(ctx context.Context, r *Request, claims map[string]interface{}) (string, error) {
arguments := m.Called(ctx, r, claims)
return arguments.String(0), arguments.Error(1)
}

func (m *mockFactory) ExpectNewToken(ctx context.Context, r *Request) *mock.Call {
return m.On("NewToken", ctx, r)
func (m *mockFactory) ExpectNewToken(ctx context.Context, r *Request, claims map[string]interface{}) *mock.Call {
return m.On("NewToken", ctx, r, claims)
}

type mockClaimBuilder struct {
Expand Down
5 changes: 5 additions & 0 deletions token/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ type Options struct {
// or statically from configuration. For special processing around the partner id, set the PartnerID field.
Claims []Value

// HeaderClaims is an optional map of claims-to-headers, where each claim will be attempted to be added as response header.
//
// Any failures to add claims are silent and does not affect the response.
HeaderClaims map[string]string

// Metadata describes non-claim data, which can be statically configured or supplied via a request
Metadata []Value

Expand Down
Loading

0 comments on commit 51bda72

Please sign in to comment.