diff --git a/acquire/acquire.go b/acquire/acquire.go index 20392f3..71355fd 100644 --- a/acquire/acquire.go +++ b/acquire/acquire.go @@ -9,7 +9,7 @@ import ( ) type Acquirer interface { - AddAuth(*http.Request) + AddAuth(*http.Request) error Acquire() (string, error) ParseToken([]byte) (string, error) ParseExpiration([]byte) (time.Time, error) diff --git a/chrysom/basicClient.go b/chrysom/basicClient.go index 720de88..fc16af7 100644 --- a/chrysom/basicClient.go +++ b/chrysom/basicClient.go @@ -83,19 +83,18 @@ type Items []model.Item // NewBasicClient creates a new BasicClient that can be used to // make requests to Argus. -func NewBasicClient(config BasicClientConfig) (*BasicClient, error) { +func NewBasicClient(config BasicClientConfig, auth acquire.Acquirer) (*BasicClient, error) { err := validateBasicConfig(&config) if err != nil { return nil, err } - tokenAcquirer, err := buildTokenAcquirer(config.Auth) if err != nil { return nil, err } clientStore := &BasicClient{ client: config.HTTPClient, - auth: tokenAcquirer, + auth: auth, bucket: config.Bucket, storeBaseURL: config.Address + storeAPIPath, } @@ -200,7 +199,7 @@ func (c *BasicClient) sendRequest(ctx context.Context, owner, method, url string if err != nil { return response{}, fmt.Errorf(errWrappedFmt, errNewRequestFailure, err.Error()) } - err = acquire.AddAuth(r, c.auth) + err = c.auth.AddAuth(r) if err != nil { return response{}, fmt.Errorf(errWrappedFmt, ErrAuthAcquirerFailure, err.Error()) } @@ -224,10 +223,6 @@ func (c *BasicClient) sendRequest(ctx context.Context, owner, method, url string return sqResp, nil } -func isEmpty(options acquire.RemoteBearerTokenAcquirerOptions) bool { - return len(options.AuthURL) < 1 || options.Buffer == 0 || options.Timeout == 0 -} - // translateNonSuccessStatusCode returns as specific error // for known Argus status codes. func translateNonSuccessStatusCode(code int) error { @@ -241,15 +236,6 @@ func translateNonSuccessStatusCode(code int) error { } } -func buildTokenAcquirer(auth Auth) (acquire.Acquirer, error) { - if !isEmpty(auth.JWT) { - return acquire.NewRemoteBearerTokenAcquirer(auth.JWT) - } else if len(auth.Basic) > 0 { - return acquire.NewFixedAuthAcquirer(auth.Basic) - } - return &acquire.DefaultAcquirer{}, nil -} - func validateBasicConfig(config *BasicClientConfig) error { if config.Address == "" { return ErrAddressEmpty diff --git a/go.mod b/go.mod index 2ed92be..86d0c30 100644 --- a/go.mod +++ b/go.mod @@ -7,13 +7,11 @@ toolchain go1.22.9 require ( github.com/aws/aws-sdk-go v1.54.19 github.com/go-kit/kit v0.13.0 - github.com/golang-jwt/jwt v3.2.2+incompatible github.com/lestrrat-go/jwx/v2 v2.1.2 github.com/prometheus/client_golang v1.19.1 github.com/prometheus/client_model v0.6.1 github.com/spf13/cast v1.6.0 github.com/stretchr/testify v1.9.0 - github.com/xmidt-org/bascule v0.11.6 github.com/xmidt-org/httpaux v0.4.0 github.com/xmidt-org/sallust v0.2.2 github.com/xmidt-org/touchstone v0.1.5 diff --git a/go.sum b/go.sum index fa68183..0fbe9c1 100644 --- a/go.sum +++ b/go.sum @@ -19,8 +19,6 @@ github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= -github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= -github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= @@ -65,8 +63,6 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/xmidt-org/bascule v0.11.6 h1:i46FAI97XPMt3OKraiNyKa+mt36AhLO8iuInAypXKNM= -github.com/xmidt-org/bascule v0.11.6/go.mod h1:BXb5PEm/tjqdiEGsd+phm+fItMJx+Huv6LTCEU/zTzg= github.com/xmidt-org/httpaux v0.4.0 h1:cAL/MzIBpSsv4xZZeq/Eu1J5M3vfNe49xr41mP3COKU= github.com/xmidt-org/httpaux v0.4.0/go.mod h1:UypqZwuZV1nn8D6+K1JDb+im9IZrLNg/2oO/Bgiybxc= github.com/xmidt-org/sallust v0.2.2 h1:MrINLEr7cMj6ENx/O76fvpfd5LNGYnk7OipZAGXPYA0= diff --git a/jwtAcquireParser.go b/jwtAcquireParser.go deleted file mode 100644 index a27ed32..0000000 --- a/jwtAcquireParser.go +++ /dev/null @@ -1,76 +0,0 @@ -// SPDX-FileCopyrightText: 2022 Comcast Cable Communications Management, LLC -// SPDX-License-Identifier: Apache-2.0 - -package ancla - -import ( - "errors" - "time" - - "github.com/golang-jwt/jwt" - "github.com/spf13/cast" - "github.com/xmidt-org/bascule/acquire" -) - -type jwtAcquireParserType string - -const ( - simpleType jwtAcquireParserType = "simple" - rawType jwtAcquireParserType = "raw" -) - -var ( - errMissingExpClaim = errors.New("missing exp claim in jwt") - errUnexpectedCasting = errors.New("unexpected casting error") -) - -type jwtAcquireParser struct { - token acquire.TokenParser - expiration acquire.ParseExpiration -} - -func rawTokenParser(data []byte) (string, error) { - return string(data), nil -} - -func rawTokenExpirationParser(data []byte) (time.Time, error) { - p := jwt.Parser{SkipClaimsValidation: true} - token, _, err := p.ParseUnverified(string(data), jwt.MapClaims{}) - if err != nil { - return time.Time{}, err - } - claims, ok := token.Claims.(jwt.MapClaims) - if !ok { - return time.Time{}, errUnexpectedCasting - } - expVal, ok := claims["exp"] - if !ok { - return time.Time{}, errMissingExpClaim - } - - exp, err := cast.ToInt64E(expVal) - if err != nil { - return time.Time{}, err - } - return time.Unix(exp, 0), nil -} - -func newJWTAcquireParser(pType jwtAcquireParserType) (jwtAcquireParser, error) { - if pType == "" { - pType = simpleType - } - if pType != simpleType && pType != rawType { - return jwtAcquireParser{}, errors.New("only 'simple' or 'raw' are supported as jwt acquire parser types") - } - // nil defaults are fine (bascule/acquire will use the simple - // default parsers internally). - var ( - tokenParser acquire.TokenParser - expirationParser acquire.ParseExpiration - ) - if pType == rawType { - tokenParser = rawTokenParser - expirationParser = rawTokenExpirationParser - } - return jwtAcquireParser{expiration: expirationParser, token: tokenParser}, nil -} diff --git a/jwtAcquireParser_test.go b/jwtAcquireParser_test.go deleted file mode 100644 index fed2a25..0000000 --- a/jwtAcquireParser_test.go +++ /dev/null @@ -1,94 +0,0 @@ -// SPDX-FileCopyrightText: 2022 Comcast Cable Communications Management, LLC -// SPDX-License-Identifier: Apache-2.0 - -package ancla - -import ( - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestNewJWTAcquireParser(t *testing.T) { - tcs := []struct { - Description string - ParserType jwtAcquireParserType - ShouldFail bool - }{ - { - Description: "Default", - }, - { - Description: "Invalid type", - ParserType: "advanced", - ShouldFail: true, - }, - { - Description: "Simple", - ParserType: simpleType, - }, - { - Description: "Raw", - ParserType: rawType, - }, - } - - for _, tc := range tcs { - t.Run(tc.Description, func(t *testing.T) { - assert := assert.New(t) - p, err := newJWTAcquireParser(tc.ParserType) - if tc.ShouldFail { - assert.NotNil(err) - assert.Nil(p.expiration) - assert.Nil(p.token) - } else { - assert.Nil(err) - if tc.ParserType == rawType { - assert.NotNil(p.expiration) - assert.NotNil(p.token) - } - } - }) - } -} - -func TestRawTokenParser(t *testing.T) { - assert := assert.New(t) - payload := []byte("eyJhbGciOiJSUzI1NiIsImtpZCI6ImRldmVsb3BtZW50IiwidHlwIjoiSldUIn0.eyJhbGxvd2VkUmVzb3VyY2VzIjp7ImFsbG93ZWRQYXJ0bmVycyI6WyJjb21jYXN0Il19LCJhdWQiOiJYTWlEVCIsImNhcGFiaWxpdGllcyI6WyJ4MTppc3N1ZXI6dGVzdDouKjphbGwiLCJ4MTppc3N1ZXI6dWk6YWxsIl0sImV4cCI6MTYyMjE1Nzk4MSwiaWF0IjoxNjIyMDcxNTgxLCJpc3MiOiJkZXZlbG9wbWVudCIsImp0aSI6ImN4ZmkybTZDWnJjaFNoZ1Nzdi1EM3ciLCJuYmYiOjE2MjIwNzE1NjYsInBhcnRuZXItaWQiOiJjb21jYXN0Iiwic3ViIjoiY2xpZW50LXN1cHBsaWVkIiwidHJ1c3QiOjEwMDB9.7QzRWJgxGs1cEZunMOewYCnEDiq2CTDh5R5F47PYhkMVb2KxSf06PRRGN-rQSWPhhBbev1fGgu63mr3yp_VDmdVvHR2oYiKyxP2skJTSzfQmiRyLMYY5LcLn3BObyQxU8EnLhnqGIjpORW0L5Dd4QsaZmXRnkC73yGnJx4XCx0I") - token, err := rawTokenParser(payload) - assert.Equal(string(payload), token) - assert.Nil(err) -} - -func TestRawExpirationParser(t *testing.T) { - tcs := []struct { - Description string - Payload []byte - ShouldFail bool - ExpectedTime time.Time - }{ - { - Description: "Not a JWT", - Payload: []byte("xyz==abcNotAJWT"), - ShouldFail: true, - }, - { - Description: "A jwt", - Payload: []byte("eyJhbGciOiJSUzI1NiIsImtpZCI6ImRldmVsb3BtZW50IiwidHlwIjoiSldUIn0.eyJhbGxvd2VkUmVzb3VyY2VzIjp7ImFsbG93ZWRQYXJ0bmVycyI6WyJjb21jYXN0Il19LCJhdWQiOiJYTWlEVCIsImNhcGFiaWxpdGllcyI6WyJ4MTppc3N1ZXI6dGVzdDouKjphbGwiLCJ4MTppc3N1ZXI6dWk6YWxsIl0sImV4cCI6MTYyMjE1Nzk4MSwiaWF0IjoxNjIyMDcxNTgxLCJpc3MiOiJkZXZlbG9wbWVudCIsImp0aSI6ImN4ZmkybTZDWnJjaFNoZ1Nzdi1EM3ciLCJuYmYiOjE2MjIwNzE1NjYsInBhcnRuZXItaWQiOiJjb21jYXN0Iiwic3ViIjoiY2xpZW50LXN1cHBsaWVkIiwidHJ1c3QiOjEwMDB9.7QzRWJgxGs1cEZunMOewYCnEDiq2CTDh5R5F47PYhkMVb2KxSf06PRRGN-rQSWPhhBbev1fGgu63mr3yp_VDmdVvHR2oYiKyxP2skJTSzfQmiRyLMYY5LcLn3BObyQxU8EnLhnqGIjpORW0L5Dd4QsaZmXRnkC73yGnJx4XCx0I"), - ExpectedTime: time.Unix(1622157981, 0), - }, - } - - for _, tc := range tcs { - assert := assert.New(t) - exp, err := rawTokenExpirationParser(tc.Payload) - if tc.ShouldFail { - assert.NotNil(err) - assert.Empty(exp) - } else { - assert.Nil(err) - assert.Equal(tc.ExpectedTime, exp) - } - } -} diff --git a/service.go b/service.go index b0edad3..ac160a0 100644 --- a/service.go +++ b/service.go @@ -10,6 +10,7 @@ import ( "net/http" "time" + "github.com/xmidt-org/ancla/acquire" "github.com/xmidt-org/ancla/chrysom" "github.com/xmidt-org/sallust" "go.uber.org/fx" @@ -47,7 +48,7 @@ type Config struct { // Simple: parser assumes token payloads have the following structure: https://github.com/xmidt-org/bascule/blob/c011b128d6b95fa8358228535c63d1945347adaa/acquire/bearer.go#L77 // Raw: parser assumes all of the token payload == JWT token // (Optional). Defaults to 'simple' - JWTParserType jwtAcquireParserType + JWTParserType string // DisablePartnerIDs, if true, will allow webhooks to register without // checking the validity of the partnerIDs in the request @@ -68,9 +69,8 @@ type ClientService struct { } // NewService builds the Argus client service from the given configuration. -func NewService(cfg Config) (*ClientService, error) { - prepArgusBasicClientConfig(&cfg) - basic, err := chrysom.NewBasicClient(cfg.BasicClientConfig) +func NewService(cfg Config, auth acquire.Acquirer) (*ClientService, error) { + basic, err := chrysom.NewBasicClient(cfg.BasicClientConfig, auth) if err != nil { return nil, fmt.Errorf("failed to create chrysom basic client: %v", err) } @@ -133,16 +133,6 @@ func (s *ClientService) GetAll(ctx context.Context) ([]Register, error) { return iws, nil } -func prepArgusBasicClientConfig(cfg *Config) error { - p, err := newJWTAcquireParser(cfg.JWTParserType) - if err != nil { - return err - } - cfg.BasicClientConfig.Auth.JWT.GetToken = p.token - cfg.BasicClientConfig.Auth.JWT.GetExpiration = p.expiration - return nil -} - func prepArgusListenerConfig(cfg *chrysom.ListenerConfig, metrics chrysom.Measures, watches ...Watch) { watches = append(watches, webhookListSizeWatch(metrics.WebhookListSizeGauge)) cfg.Listener = chrysom.ListenerFunc(func(ctx context.Context, items chrysom.Items) { @@ -160,14 +150,16 @@ func prepArgusListenerConfig(cfg *chrysom.ListenerConfig, metrics chrysom.Measur type ServiceIn struct { fx.In + Config Config Client *http.Client + Auth acquire.Acquirer } func ProvideService() fx.Option { return fx.Provide( func(in ServiceIn) (*ClientService, error) { - svc, err := NewService(in.Config) + svc, err := NewService(in.Config, in.Auth) if err != nil { return nil, errors.Join(errFailedConfig, err) }