diff --git a/auth/cache.go b/auth/cache.go index 6f3fcb04c..a3446a1be 100644 --- a/auth/cache.go +++ b/auth/cache.go @@ -17,21 +17,28 @@ limitations under the License. package auth import ( - "sync" + "errors" "time" ) -var once sync.Once var cache Store +var ErrCacheAlreadyInitialized = errors.New("cache already initialized; cannot be re-initialized") +var ErrEmptyCache = errors.New("cannot initialize cache with an empty store") + // InitCache intializes the pacakge cache with the provided cache object. // Consumers that want automatic caching when using `GetRegistryAuthenticator()` // or `GetGitCredentials()` must call this before. It should only be called once, -// all subsequent calls will be a no-op. -func InitCache(s Store) { - once.Do(func() { - cache = s - }) +// all subsequent calls will return an error. +func InitCache(s Store) error { + if cache != nil { + return ErrCacheAlreadyInitialized + } + if s == nil { + return ErrEmptyCache + } + cache = s + return nil } // GetCache returns a handle to the package level cache. diff --git a/auth/git/credentials.go b/auth/git/credentials.go index 6a5384aa3..8eeab1d20 100644 --- a/auth/git/credentials.go +++ b/auth/git/credentials.go @@ -55,17 +55,23 @@ func (c *Credentials) ToSecretData() map[string][]byte { // GetCredentials returns authentication credentials for accessing the provided // Git repository. -// If caching is enabled and cacheKey is not blank, the credentials are cached -// according to the ttl advertised by the Git provider. +// The authentication credentials will be cached if `authOpts.CacheOptions.Key` +// is not blank and caching is enabled. Caching can be enabled by either calling +// `auth.InitCache()` or specifying a cache via `authOpts.CacheOptions.Cache`. +// The credentials are cached according to the ttl advertised by the registry +// provider. func GetCredentials(ctx context.Context, provider string, authOpts *auth.AuthOptions) (*Credentials, error) { var creds Credentials - cache := auth.GetCache() - if cache != nil && authOpts != nil && authOpts.CacheKey != "" { - val, found := cache.Get(authOpts.CacheKey) - if found { - creds = val.(Credentials) - return &creds, nil + var cache auth.Store + if authOpts != nil { + cache = authOpts.GetCache() + if cache != nil && authOpts.CacheOptions.Key != "" { + val, found := cache.Get(authOpts.CacheOptions.Key) + if found { + creds = val.(Credentials) + return &creds, nil + } } } @@ -134,8 +140,8 @@ func GetCredentials(ctx context.Context, provider string, authOpts *auth.AuthOpt return nil, nil } - if cache != nil && authOpts != nil && authOpts.CacheKey != "" { - if err := cache.Set(authOpts.CacheKey, creds, expiresIn); err != nil { + if cache != nil && authOpts != nil && authOpts.CacheOptions.Key != "" { + if err := cache.Set(authOpts.CacheOptions.Key, creds, expiresIn); err != nil { return nil, err } } diff --git a/auth/git/credentials_test.go b/auth/git/credentials_test.go index 46e9ae4a1..429987e3d 100644 --- a/auth/git/credentials_test.go +++ b/auth/git/credentials_test.go @@ -42,9 +42,8 @@ import ( func TestGetCredentials(t *testing.T) { expiresAt := time.Now().UTC().Add(time.Hour) - var s auth.Store - s = testutils.NewDummyCache() - auth.InitCache(s) + auth.InitCache(testutils.NewDummyCache()) + customCache := testutils.NewDummyCache() tests := []struct { name string @@ -60,7 +59,9 @@ func TestGetCredentials(t *testing.T) { name: "get credentials from github", provider: auth.ProviderGitHub, authOpts: &auth.AuthOptions{ - CacheKey: "github-123", + CacheOptions: auth.CacheOptions{ + Key: "github-123", + }, }, responseBody: `{ "token": "access-token", @@ -93,7 +94,60 @@ func TestGetCredentials(t *testing.T) { name: "get credentials from cache", provider: auth.ProviderGitHub, authOpts: &auth.AuthOptions{ - CacheKey: "github-123", + CacheOptions: auth.CacheOptions{ + Key: "github-123", + }, + }, + expectCacheHit: true, + wantCredentials: &Credentials{ + Username: GitHubAccessTokenUsername, + Password: "access-token", + }, + }, + { + name: "get credentials from github with local cache", + provider: auth.ProviderGitHub, + authOpts: &auth.AuthOptions{ + CacheOptions: auth.CacheOptions{ + Key: "github-local-123", + Cache: customCache, + }, + }, + responseBody: `{ + "token": "access-token", + "expires_at": "2029-11-10T23:00:00Z" +}`, + beforeFunc: func(t *WithT, authOpts *auth.AuthOptions, serverURL string) { + pk, err := createPrivateKey() + t.Expect(err).ToNot(HaveOccurred()) + authOpts.Secret = &corev1.Secret{ + Data: map[string][]byte{ + github.ApiURLKey: []byte(serverURL), + github.AppIDKey: []byte("127"), + github.AppInstallationIDKey: []byte("300"), + github.AppPkKey: pk, + }, + } + }, + afterFunc: func(t *WithT, cache auth.Store, creds Credentials) { + val, ok := cache.Get("github-local-123") + t.Expect(ok).To(BeTrue()) + credentials := val.(Credentials) + t.Expect(credentials).To(Equal(creds)) + }, + wantCredentials: &Credentials{ + Username: GitHubAccessTokenUsername, + Password: "access-token", + }, + }, + { + name: "get credentials from local cache", + provider: auth.ProviderGitHub, + authOpts: &auth.AuthOptions{ + CacheOptions: auth.CacheOptions{ + Key: "github-local-123", + Cache: customCache, + }, }, expectCacheHit: true, wantCredentials: &Credentials{ @@ -130,7 +184,9 @@ func TestGetCredentials(t *testing.T) { name: "get credentials from azure", provider: auth.ProviderAzure, authOpts: &auth.AuthOptions{ - CacheKey: "azure-123", + CacheOptions: auth.CacheOptions{ + Key: "azure-123", + }, ProviderOptions: auth.ProviderOptions{ AzureOpts: []azure.ProviderOptFunc{ azure.WithCredential(&azure.FakeTokenCredential{ @@ -154,7 +210,9 @@ func TestGetCredentials(t *testing.T) { name: "get credentials from gcp", provider: auth.ProviderGCP, authOpts: &auth.AuthOptions{ - CacheKey: "gcp-123", + CacheOptions: auth.CacheOptions{ + Key: "gcp-123", + }, }, responseBody: `{ "access_token": "access-token", @@ -223,7 +281,7 @@ func TestGetCredentials(t *testing.T) { g.Expect(*creds).To(Equal(*tt.wantCredentials)) if tt.afterFunc != nil { - tt.afterFunc(g, s, *creds) + tt.afterFunc(g, tt.authOpts.GetCache(), *creds) } if tt.expectCacheHit { diff --git a/auth/options.go b/auth/options.go index 68fe5240d..8fc2f2bff 100644 --- a/auth/options.go +++ b/auth/options.go @@ -41,10 +41,30 @@ type AuthOptions struct { // providers. ProviderOptions ProviderOptions - // CacheKey is the key to use for caching the authentication credentials. - // Consumers must make sure to call `InitCache()` in order for caching to - // be enabled. - CacheKey string + // CacheOptions specifies the options to configure caching behavior of the + // authentication credentials. + CacheOptions CacheOptions +} + +// GetCache returns the cache to use for fetching/storing authentication +// credentials. +func (a *AuthOptions) GetCache() Store { + if a.CacheOptions.Cache != nil { + return a.CacheOptions.Cache + } + return GetCache() +} + +// CacheOptions contains options to configure the caching behavior of the +// authentication credentials. +type CacheOptions struct { + // Key is the key to use for caching the authentication credentials. + Key string + + // Cache is the Store to use for caching the authentication credentials. + // If specified, then the global cache specified through `auth.InitCache()` + // is ignored and the credentials are cached in this Store instead. + Cache Store } // ProviderOptions contains options to configure various authentication diff --git a/auth/registry/authenticator.go b/auth/registry/authenticator.go index 7ce6b6281..5c1057985 100644 --- a/auth/registry/authenticator.go +++ b/auth/registry/authenticator.go @@ -29,17 +29,23 @@ import ( // GetAuthenticator returns an authenticator that can provide credentials to // access the provided registry. -// If caching is enabled and authOpts.CacheKey is not blank, the authentication -// config is cached according to the ttl advertised by the registry provider. +// The authentication credentials will be cached if `authOpts.CacheOptions.Key` +// is not blank and caching is enabled. Caching can be enabled by either calling +// `auth.InitCache()` or specifying a cache via `authOpts.CacheOptions.Cache`. +// The credentials are cached according to the ttl advertised by the registry +// provider. func GetAuthenticator(ctx context.Context, registry string, provider string, authOpts *auth.AuthOptions) (authn.Authenticator, error) { var authConfig authn.AuthConfig - cache := auth.GetCache() - if cache != nil && authOpts != nil && authOpts.CacheKey != "" { - val, found := cache.Get(authOpts.CacheKey) - if found { - authConfig = val.(authn.AuthConfig) - return authn.FromConfig(authConfig), nil + var cache auth.Store + if authOpts != nil { + cache = authOpts.GetCache() + if cache != nil && authOpts.CacheOptions.Key != "" { + val, found := cache.Get(authOpts.CacheOptions.Key) + if found { + authConfig = val.(authn.AuthConfig) + return authn.FromConfig(authConfig), nil + } } } @@ -79,8 +85,8 @@ func GetAuthenticator(ctx context.Context, registry string, provider string, aut return nil, err } - if cache != nil && authOpts != nil && authOpts.CacheKey != "" { - if err := cache.Set(authOpts.CacheKey, authConfig, expiresIn); err != nil { + if cache != nil && authOpts != nil && authOpts.CacheOptions.Key != "" { + if err := cache.Set(authOpts.CacheOptions.Key, authConfig, expiresIn); err != nil { return nil, err } } diff --git a/auth/registry/authenticator_test.go b/auth/registry/authenticator_test.go index e46ea8bfb..12821a8e4 100644 --- a/auth/registry/authenticator_test.go +++ b/auth/registry/authenticator_test.go @@ -54,9 +54,8 @@ func TestGetAuthenticator(t *testing.T) { tokenStr, err := token.SignedString(pk) g.Expect(err).ToNot(HaveOccurred()) - var s auth.Store - s = testutils.NewDummyCache() - auth.InitCache(s) + auth.InitCache(testutils.NewDummyCache()) + customCache := testutils.NewDummyCache() tests := []struct { name string @@ -72,7 +71,9 @@ func TestGetAuthenticator(t *testing.T) { name: "get authenticator from gcp", provider: auth.ProviderGCP, authOpts: &auth.AuthOptions{ - CacheKey: "gcp-123", + CacheOptions: auth.CacheOptions{ + Key: "gcp-123", + }, }, responseBody: `{ "access_token": "access-token", @@ -94,10 +95,55 @@ func TestGetAuthenticator(t *testing.T) { }, }, { - name: "get authenticator from cache", + name: "get authenticator from global cache", provider: auth.ProviderGCP, authOpts: &auth.AuthOptions{ - CacheKey: "gcp-123", + CacheOptions: auth.CacheOptions{ + Key: "gcp-123", + }, + }, + expectCacheHit: true, + wantAuthConfig: &authn.AuthConfig{ + Username: gcp.DefaultGARUsername, + Password: "access-token", + }, + }, + { + name: "get authenticator from gcp with local cache", + provider: auth.ProviderGCP, + authOpts: &auth.AuthOptions{ + CacheOptions: auth.CacheOptions{ + Key: "gcp-local-123", + Cache: customCache, + }, + }, + responseBody: `{ + "access_token": "access-token", + "expires_in": 10, + "token_type": "Bearer" +}`, + beforeFunc: func(authOpts *auth.AuthOptions, serverURL string, registry *string) { + authOpts.ProviderOptions.GcpOpts = []gcp.ProviderOptFunc{gcp.WithTokenURL(serverURL), gcp.WithEmailURL(serverURL)} + }, + wantAuthConfig: &authn.AuthConfig{ + Username: gcp.DefaultGARUsername, + Password: "access-token", + }, + afterFunc: func(t *WithT, cache auth.Store, authConfig authn.AuthConfig) { + val, ok := cache.Get("gcp-local-123") + t.Expect(ok).To(BeTrue()) + ac := val.(authn.AuthConfig) + t.Expect(ac).To(Equal(authConfig)) + }, + }, + { + name: "get authenticator from global cache", + provider: auth.ProviderGCP, + authOpts: &auth.AuthOptions{ + CacheOptions: auth.CacheOptions{ + Key: "gcp-local-123", + Cache: customCache, + }, }, expectCacheHit: true, wantAuthConfig: &authn.AuthConfig{ @@ -126,7 +172,9 @@ func TestGetAuthenticator(t *testing.T) { name: "get authenticator from aws", provider: auth.ProviderAWS, authOpts: &auth.AuthOptions{ - CacheKey: "aws-123", + CacheOptions: auth.CacheOptions{ + Key: "aws-123", + }, }, responseBody: fmt.Sprintf(`{ "authorizationData": [ @@ -161,7 +209,9 @@ func TestGetAuthenticator(t *testing.T) { name: "get authenticator from azure", provider: auth.ProviderAzure, authOpts: &auth.AuthOptions{ - CacheKey: "azure-123", + CacheOptions: auth.CacheOptions{ + Key: "azure-123", + }, }, responseBody: fmt.Sprintf(`{"refresh_token": "%s"}`, tokenStr), beforeFunc: func(authOpts *auth.AuthOptions, serverURL string, registry *string) { @@ -219,7 +269,7 @@ func TestGetAuthenticator(t *testing.T) { g.Expect(*ac).To(Equal(*tt.wantAuthConfig)) if tt.afterFunc != nil { - tt.afterFunc(g, s, *ac) + tt.afterFunc(g, tt.authOpts.GetCache(), *ac) } if tt.expectCacheHit { g.Expect(count).To(Equal(0))