From 51bda727f1590b083b5c048aac49a3d3792bcbe8 Mon Sep 17 00:00:00 2001 From: Owen Cabalceta Date: Wed, 20 Nov 2024 09:33:48 -0500 Subject: [PATCH] feat: allow claims to be used for response headers - allow token claims to be used for response headers - add optional config to determine which claims will be added as what header - update tests --- .gitignore | 1 + token/claimBuilder.go | 6 +++- token/endpoint.go | 68 +++++++++++++++++++++++++++++++++++++---- token/endpoint_test.go | 38 ++++++++++++++--------- token/factory.go | 15 ++++++--- token/factory_test.go | 40 ++++++++++++++++++++++-- token/handler_test.go | 57 ++++++++++++++++++++++++++++++++-- token/mocks_test.go | 8 ++--- token/options.go | 5 +++ token/transport.go | 13 ++++++-- token/transport_test.go | 4 +-- token/unmarshal.go | 4 +-- 12 files changed, 215 insertions(+), 44 deletions(-) diff --git a/.gitignore b/.gitignore index f11fba0..c964e8b 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ report.json # for VSCode .vscode/ .dev/ +__debug_bin* # for releases .ignore/ diff --git a/token/claimBuilder.go b/token/claimBuilder.go index 3e1c668..743807f 100644 --- a/token/claimBuilder.go +++ b/token/claimBuilder.go @@ -29,6 +29,10 @@ var ( ErrMissingKey = errors.New("A key is required for all claims and metadata values") ) +const ( + jwtExpName = "exp" +) + // ClaimBuilder is a strategy for building token claims, given a token Request type ClaimBuilder interface { AddClaims(context.Context, *Request, map[string]interface{}) error @@ -88,7 +92,7 @@ func (tc *timeClaimBuilder) AddClaims(_ context.Context, r *Request, target map[ target["iat"] = now.Unix() if tc.duration > 0 { - target["exp"] = now.Add(tc.duration).Unix() + target[jwtExpName] = now.Add(tc.duration).Unix() } if !tc.disableNotBefore { diff --git a/token/endpoint.go b/token/endpoint.go index 06895ad..15cf707 100644 --- a/token/endpoint.go +++ b/token/endpoint.go @@ -4,25 +4,81 @@ package token import ( "context" + "errors" + "net/http" + "time" "github.com/go-kit/kit/endpoint" ) +const ( + JWTExpHeader = "Expires" +) + +var ( + ErrJWTExpType = errors.New("expected jwt to be of type time.Time") +) + // NewIssueEndpoint returns a go-kit endpoint for a token factory's NewToken method -func NewIssueEndpoint(f Factory) endpoint.Endpoint { +func NewIssueEndpoint(f Factory, headerClaims map[string]string) endpoint.Endpoint { return func(ctx context.Context, v interface{}) (interface{}, error) { - return f.NewToken(ctx, v.(*Request)) + req := v.(*Request) + resp := Response{ + Claims: make(map[string]interface{}, len(req.Claims)), + HeaderClaims: headerClaims, + } + token, err := f.NewToken(ctx, req, resp.Claims) + if err != nil { + return Response{}, err + } + + resp.Body = []byte(token) + + return resp, err } } // NewClaimsEndpoint returns a go-kit endpoint that returns just the claims -func NewClaimsEndpoint(cb ClaimBuilder) endpoint.Endpoint { +func NewClaimsEndpoint(cb ClaimBuilder, headerClaims map[string]string) endpoint.Endpoint { return func(ctx context.Context, v interface{}) (interface{}, error) { - merged := make(map[string]interface{}) - if err := cb.AddClaims(ctx, v.(*Request), merged); err != nil { + resp := Response{ + Claims: make(map[string]interface{}), + HeaderClaims: headerClaims, + } + if err := cb.AddClaims(ctx, v.(*Request), resp.Claims); err != nil { return nil, err } - return merged, nil + return resp.Claims, nil } } + +type Response struct { + // Claims is the set of token claims. + Claims map[string]interface{} + // HeaderClaims is a map of claims-to-headers, where each claim will be attempted to be added as response header. + HeaderClaims map[string]string + // Body is the response body used by the EncodeIssueResponse. + Body []byte +} + +// Headers creates and returns a set of http headers based on HeaderClaims. +// Any failures to add claims are silent and does not affect the response. +func (resp Response) Headers() http.Header { + headers := http.Header{} + for claimKey, headerName := range resp.HeaderClaims { + c, ok := resp.Claims[claimKey] + if !ok { + continue + } + + switch v := c.(type) { + case time.Time: + headers.Add(headerName, v.Format(http.TimeFormat)) + case string: + headers.Add(headerName, v) + } + } + + return headers +} diff --git a/token/endpoint_test.go b/token/endpoint_test.go index d21164f..e30a125 100644 --- a/token/endpoint_test.go +++ b/token/endpoint_test.go @@ -17,15 +17,22 @@ func testNewIssueEndpointSuccess(t *testing.T) { assert = assert.New(t) require = require.New(t) - factory = new(mockFactory) - request = NewRequest() - endpoint = NewIssueEndpoint(factory) + factory = new(mockFactory) + request = NewRequest() + expectedResp = Response{ + Claims: make(map[string]interface{}, len(request.Claims)), + HeaderClaims: map[string]string{}, + Body: []byte("test"), + } + endpoint = NewIssueEndpoint(factory, map[string]string{}) ) require.NotNil(endpoint) - factory.ExpectNewToken(context.Background(), request).Once().Return("test", error(nil)) - token, err := endpoint(context.Background(), request) - assert.Equal("test", token) + factory.ExpectNewToken(context.Background(), request, map[string]interface{}{}).Once().Return("test", error(nil)) + value, err := endpoint(context.Background(), request) + resp, ok := value.(Response) + require.True(ok) + assert.Equal(expectedResp, resp) assert.NoError(err) factory.AssertExpectations(t) @@ -36,16 +43,17 @@ func testNewIssueEndpointFailure(t *testing.T) { assert = assert.New(t) require = require.New(t) - factory = new(mockFactory) - expectedErr = errors.New("expected") - request = NewRequest() - endpoint = NewIssueEndpoint(factory) + factory = new(mockFactory) + expectedErr = errors.New("expected") + request = NewRequest() + expectedResp = Response{} + endpoint = NewIssueEndpoint(factory, map[string]string{}) ) require.NotNil(endpoint) - factory.ExpectNewToken(context.Background(), request).Once().Return("", expectedErr) - token, actualErr := endpoint(context.Background(), request) - assert.Equal("", token) + factory.ExpectNewToken(context.Background(), request, map[string]interface{}{}).Once().Return("", expectedErr) + resp, actualErr := endpoint(context.Background(), request) + assert.Equal(expectedResp, resp) assert.Equal(expectedErr, actualErr) factory.AssertExpectations(t) @@ -64,7 +72,7 @@ func testNewClaimsEndpointSuccess(t *testing.T) { builder = new(mockClaimBuilder) expectedClaims = map[string]interface{}{"key": "value"} request = NewRequest() - endpoint = NewClaimsEndpoint(builder) + endpoint = NewClaimsEndpoint(builder, map[string]string{}) ) require.NotNil(endpoint) @@ -88,7 +96,7 @@ func testNewClaimsEndpointFailure(t *testing.T) { builder = new(mockClaimBuilder) expectedErr = errors.New("expected") request = NewRequest() - endpoint = NewClaimsEndpoint(builder) + endpoint = NewClaimsEndpoint(builder, map[string]string{}) ) require.NotNil(endpoint) diff --git a/token/factory.go b/token/factory.go index 9ee6c6e..ec4bda7 100644 --- a/token/factory.go +++ b/token/factory.go @@ -51,7 +51,7 @@ func NewRequest() *Request { // Factory is a creation strategy for signed JWT tokens type Factory interface { // NewToken uses a Request to produce a signed JWT token - NewToken(context.Context, *Request) (string, error) + NewToken(context.Context, *Request, map[string]interface{}) (string, error) } type factory struct { @@ -62,13 +62,18 @@ type factory struct { pair atomic.Value } -func (f *factory) NewToken(ctx context.Context, r *Request) (string, error) { - merged := make(map[string]interface{}, len(r.Claims)) - if err := f.claimBuilder.AddClaims(ctx, r, merged); err != nil { +// NewToken returns a token based on a given request. +func (f *factory) NewToken(ctx context.Context, r *Request, claims map[string]interface{}) (string, error) { + // claims will are non nil when responses need to access them. + if claims == nil { + claims = make(map[string]interface{}, len(r.Claims)) + } + + if err := f.claimBuilder.AddClaims(ctx, r, claims); err != nil { return "", err } - token := jwt.NewWithClaims(f.method, jwt.MapClaims(merged)) + token := jwt.NewWithClaims(f.method, jwt.MapClaims(claims)) pair := f.pair.Load().(key.Pair) token.Header["kid"] = pair.KID() return token.SignedString(pair.Sign()) diff --git a/token/factory_test.go b/token/factory_test.go index 56b6851..79f9d1d 100644 --- a/token/factory_test.go +++ b/token/factory_test.go @@ -9,6 +9,7 @@ import ( "github.com/xmidt-org/themis/key" "github.com/xmidt-org/themis/random" + "github.com/xmidt-org/themis/random/randomtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -48,7 +49,7 @@ func testNewFactoryInvalidKeyType(t *testing.T) { assert.Error(err) } -func testNewFactorySuccess(t *testing.T) { +func testNewFactoryWithoutReponseClaimsSuccess(t *testing.T) { var ( assert = assert.New(t) require = require.New(t) @@ -72,13 +73,46 @@ func testNewFactorySuccess(t *testing.T) { require.NoError(err) require.NotNil(factory) - token, err := factory.NewToken(context.Background(), new(Request)) + token, err := factory.NewToken(context.Background(), new(Request), nil) require.NoError(err) assert.True(len(token) > 0) } +func testNewFactoryWithReponseClaimsSuccess(t *testing.T) { + var ( + assert = assert.New(t) + require = require.New(t) + registry = key.NewRegistry(rand.Reader) + noncer = new(randomtest.Noncer) + expectedClaims = map[string]interface{}{"jti": "deadbeef"} + cb = ClaimBuilders{ + nonceClaimBuilder{n: noncer}, + } + ) + + noncer.ExpectNonce().Return("deadbeef", error(nil)).Once() + factory, err := NewFactory(Options{ + Alg: "RS256", + Key: key.Descriptor{ + Kid: "test", + Bits: 512, + }, + Nonce: true, + }, cb, registry) + + require.NoError(err) + require.NotNil(factory) + + resp := Response{Claims: make(map[string]interface{})} + token, err := factory.NewToken(context.Background(), new(Request), resp.Claims) + require.NoError(err) + assert.True(len(token) > 0) + assert.Equal(expectedClaims, resp.Claims) +} + func TestNewFactory(t *testing.T) { t.Run("InvalidAlg", testNewFactoryInvalidAlg) t.Run("InvalidKeyType", testNewFactoryInvalidKeyType) - t.Run("Success", testNewFactorySuccess) + t.Run("WithoutReponseClaims Success", testNewFactoryWithoutReponseClaimsSuccess) + t.Run("WithReponseClaims Success", testNewFactoryWithReponseClaimsSuccess) } diff --git a/token/handler_test.go b/token/handler_test.go index 5f452c4..ba742f9 100644 --- a/token/handler_test.go +++ b/token/handler_test.go @@ -15,20 +15,70 @@ import ( "github.com/stretchr/testify/require" ) -func TestNewIssueHandler(t *testing.T) { +func TestNewIssueHandlerWithoutClaimHeaders(t *testing.T) { var ( assert = assert.New(t) require = require.New(t) endpoint = endpoint.Endpoint(func(_ context.Context, v interface{}) (interface{}, error) { - var output bytes.Buffer + var ( + resp = Response{} + output bytes.Buffer + ) output.WriteString("endpoint=run") for key, value := range v.(*Request).Claims { output.WriteRune(',') fmt.Fprintf(&output, "%s=%s", key, value) } - return output.String(), nil + resp.Body = output.Bytes() + + return resp, nil + }) + + builders = RequestBuilders{ + RequestBuilderFunc(func(original *http.Request, r *Request) error { + r.Claims["claim"] = original.Header.Get("Claim") + return nil + }), + } + + handler = NewIssueHandler(endpoint, builders) + response = httptest.NewRecorder() + request = httptest.NewRequest("POST", "/", nil) + ) + + require.NotNil(handler) + request.Header.Set("Claim", "fromHeader") + handler.ServeHTTP(response, request) + assert.Equal("application/jose", response.Header().Get("Content-Type")) + assert.Empty(response.Header().Get("claim")) + assert.Equal("endpoint=run,claim=fromHeader", response.Body.String()) +} + +func TestNewIssueHandlerWithClaimHeaders(t *testing.T) { + var ( + assert = assert.New(t) + require = require.New(t) + + endpoint = endpoint.Endpoint(func(_ context.Context, v interface{}) (interface{}, error) { + var ( + resp = Response{ + Claims: make(map[string]interface{}), + HeaderClaims: map[string]string{"claim": "HeaderClaim"}, + } + output bytes.Buffer + ) + output.WriteString("endpoint=run") + for key, value := range v.(*Request).Claims { + output.WriteRune(',') + fmt.Fprintf(&output, "%s=%s", key, value) + resp.Claims[key] = value + } + + resp.Body = output.Bytes() + + return resp, nil }) builders = RequestBuilders{ @@ -47,6 +97,7 @@ func TestNewIssueHandler(t *testing.T) { request.Header.Set("Claim", "fromHeader") handler.ServeHTTP(response, request) assert.Equal("application/jose", response.Header().Get("Content-Type")) + assert.Equal("fromHeader", response.Header().Get("HeaderClaim")) assert.Equal("endpoint=run,claim=fromHeader", response.Body.String()) } diff --git a/token/mocks_test.go b/token/mocks_test.go index 0968ad2..2504c9d 100644 --- a/token/mocks_test.go +++ b/token/mocks_test.go @@ -12,13 +12,13 @@ type mockFactory struct { mock.Mock } -func (m *mockFactory) NewToken(ctx context.Context, r *Request) (string, error) { - arguments := m.Called(ctx, r) +func (m *mockFactory) NewToken(ctx context.Context, r *Request, claims map[string]interface{}) (string, error) { + arguments := m.Called(ctx, r, claims) return arguments.String(0), arguments.Error(1) } -func (m *mockFactory) ExpectNewToken(ctx context.Context, r *Request) *mock.Call { - return m.On("NewToken", ctx, r) +func (m *mockFactory) ExpectNewToken(ctx context.Context, r *Request, claims map[string]interface{}) *mock.Call { + return m.On("NewToken", ctx, r, claims) } type mockClaimBuilder struct { diff --git a/token/options.go b/token/options.go index 01ac2d0..cd80eb1 100644 --- a/token/options.go +++ b/token/options.go @@ -108,6 +108,11 @@ type Options struct { // or statically from configuration. For special processing around the partner id, set the PartnerID field. Claims []Value + // HeaderClaims is an optional map of claims-to-headers, where each claim will be attempted to be added as response header. + // + // Any failures to add claims are silent and does not affect the response. + HeaderClaims map[string]string + // Metadata describes non-claim data, which can be statically configured or supplied via a request Metadata []Value diff --git a/token/transport.go b/token/transport.go index 3832c9a..44bbf3d 100644 --- a/token/transport.go +++ b/token/transport.go @@ -328,9 +328,16 @@ func DecodeServerRequest(rb RequestBuilders) func(context.Context, *http.Request } } -func EncodeIssueResponse(_ context.Context, response http.ResponseWriter, value interface{}) error { - response.Header().Set("Content-Type", "application/jose") - _, err := response.Write([]byte(value.(string))) +func EncodeIssueResponse(_ context.Context, w http.ResponseWriter, value interface{}) error { + resp := value.(Response) + for k, values := range resp.Headers() { + for _, v := range values { + w.Header().Set(k, v) + } + } + + w.Header().Set("Content-Type", "application/jose") + _, err := w.Write(resp.Body) return err } diff --git a/token/transport_test.go b/token/transport_test.go index 6c96b8d..75228de 100644 --- a/token/transport_test.go +++ b/token/transport_test.go @@ -612,7 +612,7 @@ func TestEncodeIssueResponse(t *testing.T) { assert = assert.New(t) require = require.New(t) - expectedValue = "expected" + expectedValue = Response{Body: []byte("expected")} response = httptest.NewRecorder() ) @@ -621,7 +621,7 @@ func TestEncodeIssueResponse(t *testing.T) { ) assert.Equal("application/jose", response.Header().Get("Content-Type")) - assert.Equal(expectedValue, response.Body.String()) + assert.Equal(expectedValue.Body, response.Body.Bytes()) } func testDecodeRemoteClaimsResponseSuccess(t *testing.T) { diff --git a/token/unmarshal.go b/token/unmarshal.go index 93b622f..3455f69 100644 --- a/token/unmarshal.go +++ b/token/unmarshal.go @@ -58,11 +58,11 @@ func Unmarshal(configKey string, b ...RequestBuilder) func(TokenIn) (TokenOut, e ClaimBuilder: cb, Factory: f, IssueHandler: NewIssueHandler( - NewIssueEndpoint(f), + NewIssueEndpoint(f, o.HeaderClaims), rb, ), ClaimsHandler: NewClaimsHandler( - NewClaimsEndpoint(cb), + NewClaimsEndpoint(cb, o.HeaderClaims), rb, ), }, nil