diff --git a/cmd/copy.go b/cmd/copy.go index ffe19ce90..bad30d0bd 100644 --- a/cmd/copy.go +++ b/cmd/copy.go @@ -1305,7 +1305,8 @@ func (cca *CookedCopyCmdArgs) processRedirectionDownload(blobResource common.Res } // step 1: create client options - options := &blockblob.ClientOptions{ClientOptions: createClientOptions(azcopyScanningLogger, nil)} + // note: dstCred is nil, as we could not reauth effectively because stdout is a pipe. + options := &blockblob.ClientOptions{ClientOptions: createClientOptions(azcopyScanningLogger, nil, nil)} // step 2: parse source url u, err := blobResource.FullURL() @@ -1381,8 +1382,15 @@ func (cca *CookedCopyCmdArgs) processRedirectionUpload(blobResource common.Resou return fmt.Errorf("fatal: cannot find auth on destination blob URL: %s", err.Error()) } + var reauthTok *common.ScopedAuthenticator + if at, ok := credInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } + // step 0: initialize pipeline - options := &blockblob.ClientOptions{ClientOptions: createClientOptions(common.AzcopyCurrentJobLogger, nil)} + // Reauthentication is theoretically possible here, since stdin is blocked. + options := &blockblob.ClientOptions{ClientOptions: createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok)} // step 1: parse destination url u, err := blobResource.FullURL() @@ -1564,7 +1572,18 @@ func (cca *CookedCopyCmdArgs) processCopyJobPartOrders() (err error) { }, } - options := createClientOptions(common.AzcopyCurrentJobLogger, nil) + srcCredInfo, err := cca.getSrcCredential(ctx, &jobPartOrder) + if err != nil { + return err + } + + var srcReauth *common.ScopedAuthenticator + if at, ok := srcCredInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + srcReauth = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } + + options := createClientOptions(common.AzcopyCurrentJobLogger, nil, srcReauth) var azureFileSpecificOptions any if cca.FromTo.From() == common.ELocation.File() { azureFileSpecificOptions = &common.FileClientOptions{ @@ -1572,10 +1591,6 @@ func (cca *CookedCopyCmdArgs) processCopyJobPartOrders() (err error) { } } - srcCredInfo, err := cca.getSrcCredential(ctx, &jobPartOrder) - if err != nil { - return err - } jobPartOrder.SrcServiceClient, err = common.GetServiceClientForLocation( cca.FromTo.From(), cca.Source, @@ -1595,11 +1610,17 @@ func (cca *CookedCopyCmdArgs) processCopyJobPartOrders() (err error) { } } - var srcCred *common.ScopedCredential + var dstReauthTok *common.ScopedAuthenticator + if at, ok := cca.credentialInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + dstReauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } + + var srcCred *common.ScopedToken if cca.FromTo.IsS2S() && srcCredInfo.CredentialType.IsAzureOAuth() { srcCred = common.NewScopedCredential(srcCredInfo.OAuthTokenInfo.TokenCredential, srcCredInfo.CredentialType) } - options = createClientOptions(common.AzcopyCurrentJobLogger, srcCred) + options = createClientOptions(common.AzcopyCurrentJobLogger, srcCred, dstReauthTok) jobPartOrder.DstServiceClient, err = common.GetServiceClientForLocation( cca.FromTo.To(), cca.Destination, diff --git a/cmd/copyEnumeratorInit.go b/cmd/copyEnumeratorInit.go index 73c537a34..b7ef9cad7 100755 --- a/cmd/copyEnumeratorInit.go +++ b/cmd/copyEnumeratorInit.go @@ -428,7 +428,13 @@ func (cca *CookedCopyCmdArgs) createDstContainer(containerName string, dstWithSA return err } - options := createClientOptions(common.AzcopyCurrentJobLogger, nil) + var reauthTok *common.ScopedAuthenticator + if at, ok := dstCredInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } + + options := createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok) sc, err := common.GetServiceClientForLocation( cca.FromTo.To(), diff --git a/cmd/credentialUtil.go b/cmd/credentialUtil.go index e3bda9d3c..6b09c48b6 100644 --- a/cmd/credentialUtil.go +++ b/cmd/credentialUtil.go @@ -365,7 +365,7 @@ func isPublic(ctx context.Context, blobResourceURL string, cpkOptions common.Cpk MaxRetryDelay: ste.UploadMaxRetryDelay, }, policy.TelemetryOptions{ ApplicationID: common.AddUserAgentPrefix(common.UserAgent), - }, nil, ste.LogOptions{}, nil) + }, nil, ste.LogOptions{}, nil, nil) blobClient, _ := blob.NewClientWithNoCredential(bURLParts.String(), &blob.ClientOptions{ClientOptions: clientOptions}) bURLParts.BlobName = "" @@ -398,7 +398,7 @@ func mdAccountNeedsOAuth(ctx context.Context, blobResourceURL string, cpkOptions MaxRetryDelay: ste.UploadMaxRetryDelay, }, policy.TelemetryOptions{ ApplicationID: common.AddUserAgentPrefix(common.UserAgent), - }, nil, ste.LogOptions{}, nil) + }, nil, ste.LogOptions{}, nil, nil) blobClient, _ := blob.NewClientWithNoCredential(blobResourceURL, &blob.ClientOptions{ClientOptions: clientOptions}) _, err := blobClient.GetProperties(ctx, &blob.GetPropertiesOptions{CPKInfo: cpkOptions.GetCPKInfo()}) @@ -577,7 +577,7 @@ func getCredentialType(ctx context.Context, raw rawFromToInfo, cpkOptions common // createClientOptions creates generic client options which are required to create any // client to interact with storage service. Default options are modified to suit azcopy. // srcCred is required in cases where source is authenticated via oAuth for S2S transfers -func createClientOptions(logger common.ILoggerResetable, srcCred *common.ScopedCredential) azcore.ClientOptions { +func createClientOptions(logger common.ILoggerResetable, srcCred *common.ScopedToken, reauthCred *common.ScopedAuthenticator) azcore.ClientOptions { logOptions := ste.LogOptions{} if logger != nil { @@ -592,7 +592,7 @@ func createClientOptions(logger common.ILoggerResetable, srcCred *common.ScopedC MaxRetryDelay: ste.UploadMaxRetryDelay, }, policy.TelemetryOptions{ ApplicationID: common.AddUserAgentPrefix(common.UserAgent), - }, ste.NewAzcopyHTTPClient(frontEndMaxIdleConnectionsPerHost), logOptions, srcCred) + }, ste.NewAzcopyHTTPClient(frontEndMaxIdleConnectionsPerHost), logOptions, srcCred, reauthCred) } const frontEndMaxIdleConnectionsPerHost = http.DefaultMaxIdleConnsPerHost diff --git a/cmd/jobsResume.go b/cmd/jobsResume.go index 5d8b8acc7..91afae14c 100644 --- a/cmd/jobsResume.go +++ b/cmd/jobsResume.go @@ -294,13 +294,19 @@ func (rca resumeCmdArgs) getSourceAndDestinationServiceClients( } } - options := createClientOptions(common.AzcopyCurrentJobLogger, nil) + var reauthTok *common.ScopedAuthenticator + if at, ok := tc.(common.AuthenticateToken); ok { // We don't need two different tokens here since it gets passed in just the same either way. + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } jobID, err := common.ParseJobID(rca.jobID) if err != nil { // Error for invalid JobId format return nil, nil, fmt.Errorf("error parsing the jobId %s. Failed with error %s", rca.jobID, err.Error()) } + // But we don't want to supply a reauth token if we're not using OAuth. That could cause problems if say, a SAS is invalid. + options := createClientOptions(common.AzcopyCurrentJobLogger, nil, common.Iff(srcCredType.IsAzureOAuth(), reauthTok, nil)) var getJobDetailsResponse common.GetJobDetailsResponse // Get job details from the STE Rpc(common.ERpcCmd.GetJobDetails(), @@ -321,11 +327,11 @@ func (rca resumeCmdArgs) getSourceAndDestinationServiceClients( return nil, nil, err } - var srcCred *common.ScopedCredential + var srcCred *common.ScopedToken if fromTo.IsS2S() && srcCredType.IsAzureOAuth() { srcCred = common.NewScopedCredential(tc, srcCredType) } - options = createClientOptions(common.AzcopyCurrentJobLogger, srcCred) + options = createClientOptions(common.AzcopyCurrentJobLogger, srcCred, common.Iff(dstCredType.IsAzureOAuth(), reauthTok, nil)) var fileClientOptions any if fromTo.To() == common.ELocation.File() { fileClientOptions = &common.FileClientOptions{ diff --git a/cmd/make.go b/cmd/make.go index 13c900c20..4214461c1 100644 --- a/cmd/make.go +++ b/cmd/make.go @@ -90,8 +90,14 @@ func (cookedArgs cookedMakeCmdArgs) process() (err error) { return err } + var reauthTok *common.ScopedAuthenticator + if at, ok := credentialInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { // We don't need two different tokens here since it gets passed in just the same either way. + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } + // Note : trailing dot is only applicable to file operations anyway, so setting this to false - options := createClientOptions(common.AzcopyCurrentJobLogger, nil) + options := createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok) resourceURL := cookedArgs.resourceURL.String() cred := credentialInfo.OAuthTokenInfo.TokenCredential diff --git a/cmd/removeEnumerator.go b/cmd/removeEnumerator.go index f1f1c4339..1f3259a7a 100755 --- a/cmd/removeEnumerator.go +++ b/cmd/removeEnumerator.go @@ -90,7 +90,14 @@ func newRemoveEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator, er if !from.SupportsTrailingDot() { cca.trailingDot = common.ETrailingDotOption.Disable() } - options := createClientOptions(common.AzcopyCurrentJobLogger, nil) + + var reauthTok *common.ScopedAuthenticator + if at, ok := cca.credentialInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { // We don't need two different tokens here since it gets passed in just the same either way. + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } + + options := createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok) var fileClientOptions any if cca.FromTo.From() == common.ELocation.File() { fileClientOptions = &common.FileClientOptions{AllowTrailingDot: cca.trailingDot.IsEnabled()} @@ -142,7 +149,13 @@ func newRemoveEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator, er func removeBfsResources(cca *CookedCopyCmdArgs) (err error) { ctx := context.WithValue(context.Background(), ste.ServiceAPIVersionOverride, ste.DefaultServiceApiVersion) sourceURL, _ := cca.Source.String() - options := createClientOptions(common.AzcopyCurrentJobLogger, nil) + var reauthTok *common.ScopedAuthenticator + if at, ok := cca.credentialInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { // We don't need two different tokens here since it gets passed in just the same either way. + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } + + options := createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok) targetServiceClient, err := common.GetServiceClientForLocation(cca.FromTo.From(), cca.Source, cca.credentialInfo.CredentialType, cca.credentialInfo.OAuthTokenInfo.TokenCredential, &options, nil) if err != nil { diff --git a/cmd/root.go b/cmd/root.go index f6ddf790a..07257a66e 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -311,7 +311,7 @@ func beginDetectNewVersion() chan struct{} { PrintOlderVersion(*cachedVersion, *localVersion) } else { // step 2: initialize pipeline - options := createClientOptions(nil, nil) + options := createClientOptions(nil, nil, nil) // step 3: start download blobClient, err := blob.NewClientWithNoCredential(versionMetadataUrl, &blob.ClientOptions{ClientOptions: options}) diff --git a/cmd/setPropertiesEnumerator.go b/cmd/setPropertiesEnumerator.go index 10c3a5b4a..ac5eecf60 100755 --- a/cmd/setPropertiesEnumerator.go +++ b/cmd/setPropertiesEnumerator.go @@ -72,7 +72,13 @@ func setPropertiesEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator jobsAdmin.JobsAdmin.LogToJobLog(message, common.LogInfo) } - options := createClientOptions(common.AzcopyCurrentJobLogger, nil) + var reauthTok *common.ScopedAuthenticator + if at, ok := cca.credentialInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { // We don't need two different tokens here since it gets passed in just the same either way. + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } + + options := createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok) var fileClientOptions any if cca.FromTo.From() == common.ELocation.File() { fileClientOptions = &common.FileClientOptions{AllowTrailingDot: cca.trailingDot.IsEnabled()} diff --git a/cmd/syncEnumerator.go b/cmd/syncEnumerator.go index 71ae2e8bf..a02bf962e 100644 --- a/cmd/syncEnumerator.go +++ b/cmd/syncEnumerator.go @@ -179,7 +179,13 @@ func (cca *cookedSyncCmdArgs) initEnumerator(ctx context.Context) (enumerator *s }, } - options := createClientOptions(common.AzcopyCurrentJobLogger, nil) + var srcReauthTok *common.ScopedAuthenticator + if at, ok := srcCredInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + srcReauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } + + options := createClientOptions(common.AzcopyCurrentJobLogger, nil, srcReauthTok) // Create Source Client. var azureFileSpecificOptions any @@ -209,12 +215,18 @@ func (cca *cookedSyncCmdArgs) initEnumerator(ctx context.Context) (enumerator *s } } - var srcTokenCred *common.ScopedCredential + var dstReauthTok *common.ScopedAuthenticator + if at, ok := srcCredInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + dstReauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } + + var srcTokenCred *common.ScopedToken if cca.fromTo.IsS2S() && srcCredInfo.CredentialType.IsAzureOAuth() { srcTokenCred = common.NewScopedCredential(srcCredInfo.OAuthTokenInfo.TokenCredential, srcCredInfo.CredentialType) } - options = createClientOptions(common.AzcopyCurrentJobLogger, srcTokenCred) + options = createClientOptions(common.AzcopyCurrentJobLogger, srcTokenCred, dstReauthTok) copyJobTemplate.DstServiceClient, err = common.GetServiceClientForLocation( cca.fromTo.To(), cca.destination, diff --git a/cmd/versionChecker_test.go b/cmd/versionChecker_test.go index c65597de9..7ddae4f83 100644 --- a/cmd/versionChecker_test.go +++ b/cmd/versionChecker_test.go @@ -215,7 +215,7 @@ func TestCheckReleaseMetadata(t *testing.T) { a := assert.New(t) // sanity test for checking if the release metadata exists and can be downloaded - options := createClientOptions(nil, nil) + options := createClientOptions(nil, nil, nil) blobClient, err := blob.NewClientWithNoCredential(versionMetadataUrl, &blob.ClientOptions{ClientOptions: options}) a.NoError(err) diff --git a/cmd/zc_enumerator.go b/cmd/zc_enumerator.go index 66836294b..907e08250 100644 --- a/cmd/zc_enumerator.go +++ b/cmd/zc_enumerator.go @@ -381,7 +381,13 @@ func InitResourceTraverser(resource common.ResourceString, location common.Locat return output, nil } - options := createClientOptions(azcopyScanningLogger, nil) + var reauthTok *common.ScopedAuthenticator + if at, ok := credential.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } + + options := createClientOptions(azcopyScanningLogger, nil, reauthTok) switch location { case common.ELocation.Local(): diff --git a/common/oauthTokenManager.go b/common/oauthTokenManager.go index 385993456..794aa50ab 100644 --- a/common/oauthTokenManager.go +++ b/common/oauthTokenManager.go @@ -701,6 +701,11 @@ func (credInfo *OAuthTokenInfo) GetDeviceCodeCredential() (azcore.TokenCredentia Cloud: cloud.Configuration{ActiveDirectoryAuthorityHost: authorityHost.String()}, Transport: newAzcopyHTTPClient(), }, + UserPrompt: func(ctx context.Context, message azidentity.DeviceCodeMessage) error { + lcm.Info(fmt.Sprintf("Authentication is required. To sign in, open the webpage %s and enter the code %s to authenticate.", + Iff(message.VerificationURL != "", message.VerificationURL, "https://aka.ms/devicelogin"), message.UserCode)) + return nil + }, }) if err != nil { return nil, err @@ -727,6 +732,11 @@ func (credInfo *OAuthTokenInfo) GetDeviceCodeCredential() (azcore.TokenCredentia return tc, nil } +type AuthenticateToken interface { + azcore.TokenCredential + Authenticate(ctx context.Context, opts *policy.TokenRequestOptions) (azidentity.AuthenticationRecord, error) +} + func (credInfo *OAuthTokenInfo) GetTokenCredential() (azcore.TokenCredential, error) { // Token Credential is cached. if credInfo.TokenCredential != nil { diff --git a/common/output.go b/common/output.go index 910f169a5..9df415dd8 100644 --- a/common/output.go +++ b/common/output.go @@ -68,6 +68,7 @@ var EPromptType = PromptType("") type PromptType string +func (PromptType) Reauth() PromptType { return PromptType("Reauth") } func (PromptType) Cancel() PromptType { return PromptType("Cancel") } func (PromptType) Overwrite() PromptType { return PromptType("Overwrite") } func (PromptType) DeleteDestination() PromptType { return PromptType("DeleteDestination") } diff --git a/common/util.go b/common/util.go index e7708dab8..b81a95d3c 100644 --- a/common/util.go +++ b/common/util.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "net" "net/url" "strings" @@ -231,7 +232,7 @@ func GetServiceClientForLocation(loc Location, // NewScopedCredential takes in a credInfo object and returns ScopedCredential // if credentialType is either MDOAuth or oAuth. For anything else, // nil is returned -func NewScopedCredential(cred azcore.TokenCredential, credType CredentialType) *ScopedCredential { +func NewScopedCredential[T azcore.TokenCredential](cred T, credType CredentialType) *ScopedCredential[T] { var scope string if !credType.IsAzureOAuth() { return nil @@ -240,18 +241,29 @@ func NewScopedCredential(cred azcore.TokenCredential, credType CredentialType) * } else if credType == ECredentialType.OAuthToken() { scope = StorageScope } - return &ScopedCredential{cred: cred, scopes: []string{scope}} + return &ScopedCredential[T]{cred: cred, scopes: []string{scope}} } -type ScopedCredential struct { - cred azcore.TokenCredential +type ScopedCredential[T azcore.TokenCredential] struct { + cred T scopes []string } -func (s *ScopedCredential) GetToken(ctx context.Context, _ policy.TokenRequestOptions) (azcore.AccessToken, error) { +func (s *ScopedCredential[T]) GetToken(ctx context.Context, _ policy.TokenRequestOptions) (azcore.AccessToken, error) { return s.cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: s.scopes, EnableCAE: true}) } +type ScopedToken = ScopedCredential[azcore.TokenCredential] +type ScopedAuthenticator ScopedCredential[AuthenticateToken] + +func (s *ScopedAuthenticator) GetToken(ctx context.Context, _ policy.TokenRequestOptions) (azcore.AccessToken, error) { + return s.cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: s.scopes, EnableCAE: true}) +} + +func (s *ScopedAuthenticator) Authenticate(ctx context.Context, _ *policy.TokenRequestOptions) (azidentity.AuthenticationRecord, error) { + return s.cred.Authenticate(ctx, &policy.TokenRequestOptions{Scopes: s.scopes, EnableCAE: true}) +} + type ServiceClient struct { fsc *fileservice.Client bsc *blobservice.Client diff --git a/ste/destReauthPolicy.go b/ste/destReauthPolicy.go new file mode 100644 index 000000000..f29038653 --- /dev/null +++ b/ste/destReauthPolicy.go @@ -0,0 +1,124 @@ +package ste + +import ( + "context" + "errors" + "fmt" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror" + "github.com/Azure/azure-storage-azcopy/v10/common" + "net/http" + "sync" + "time" +) + +/* +RESPONSE Status: 401 Server failed to authenticate the request. Please refer to the information in the www-authenticate header. +Www-Authenticate: Bearer authorization_uri=https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47/oauth2/authorize resource_id=https://storage.azure.com +X-Ms-Error-Code: InvalidAuthenticationInfo + InvalidAuthenticationInfoServer failed to authenticate the request. Please refer to the information in the www-authenticate header. Lifetime validation failed. The token is expired. +*/ + +type destReauthPolicy struct { + cred *common.ScopedAuthenticator +} + +var reauthLock *sync.Cond = sync.NewCond(&sync.Mutex{}) + +func NewDestReauthPolicy(cred *common.ScopedAuthenticator) policy.Policy { + return &destReauthPolicy{cred} +} + +type destReauthDebug string + +var ( + destReauthDebugExecuted destReauthDebug = "executed" + destReauthDebugNoPrompt destReauthDebug = "destReauthNoPrompt" + destReauthDebugCause destReauthDebug = "destReauthCause" + destReauthDebugCauseAuthenticationRequired destReauthDebug = "AuthenticationRequiredError" + destReauthDebugCauseInvalidAuthenticationInfo destReauthDebug = "InvalidAuthenticationInfoError" +) + +func (d *destReauthPolicy) Do(req *policy.Request) (*http.Response, error) { +retry: + ctx := req.Raw().Context() + debugCtx := context.WithValue(ctx, destReauthDebugExecuted, true) + + clone := req.Clone(ctx) + resp, err := clone.Next() // Initially attempt the request. + + if err != nil || resp.StatusCode != http.StatusOK { // But, if we get back an error... + var authReq *azidentity.AuthenticationRequiredError + var respErr = &azcore.ResponseError{} + + reauth := false + + switch { // Is it an error we can resolve by re-authing? + case errors.As(err, &authReq): + reauth = true + debugCtx = context.WithValue(debugCtx, destReauthDebugCause, destReauthDebugCauseAuthenticationRequired) + case resp.StatusCode == http.StatusUnauthorized: + errors.As(runtime.NewResponseError(resp), &respErr) + reauth = err == nil && + bloberror.HasCode(respErr, bloberror.InvalidAuthenticationInfo) && + len(respErr.RawResponse.Header.Values("WWW-Authenticate")) != 0 + if reauth { + debugCtx = context.WithValue(debugCtx, destReauthDebugCause, destReauthDebugCauseInvalidAuthenticationInfo) + } + } + + if reauth { // If it is, pull the lock if we can, reauth + m := reauthLock.L.(*sync.Mutex) + + if m.TryLock() { // Fetch the lock and try until we get auth. + for { + if ctx.Value(destReauthDebugNoPrompt) == nil { + _ = common.GetLifecycleMgr().Prompt("Authentication is required to continue the job. Reauthorize and continue?", common.PromptDetails{ + PromptType: common.EPromptType.Reauth(), + ResponseOptions: []common.ResponseOption{ + common.EResponseOption.Yes(), + }, + }) + } + + _, err = d.cred.Authenticate(debugCtx, &policy.TokenRequestOptions{ + Scopes: []string{}, + }) + + // I (Adele Reed) was initially worried about every case + // Thinking about it further, the worst case is that the job ends automatically, or when the user asks it to end. + // To avoid having to handle every error, we'll catch the cancel case as a way to exit the routine, but otherwise + // we will let it happen, and just retry. + if err == nil { + break + } else { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + select { + case <-ctx.Done(): // If it was us, exit like asked. + return nil, err // If it was us, that's legitimately important. + default: // If it was them, we don't care. + } + } else { + common.GetLifecycleMgr().Info(fmt.Sprintf("Authentication failed, awaiting input to continue: %s", err)) + } + + time.Sleep(time.Second * 5) + } + } + + m.Unlock() + reauthLock.Broadcast() + } else { // Otherwise, wait for a signal that we can try again. + reauthLock.Wait() + } + + // Try the request once more + goto retry + } // If it wasn't, we won't retry, and we'll simply return the error. + } + + return resp, err +} diff --git a/ste/mgr-JobPartMgr.go b/ste/mgr-JobPartMgr.go index 62bf9f686..f7a487291 100644 --- a/ste/mgr-JobPartMgr.go +++ b/ste/mgr-JobPartMgr.go @@ -129,12 +129,15 @@ func (d *dialRateLimiter) DialContext(ctx context.Context, network, address stri return d.dialer.DialContext(ctx, network, address) } -func NewClientOptions(retry policy.RetryOptions, telemetry policy.TelemetryOptions, transport policy.Transporter, log LogOptions, srcCred *common.ScopedCredential) azcore.ClientOptions { +func NewClientOptions(retry policy.RetryOptions, telemetry policy.TelemetryOptions, transport policy.Transporter, log LogOptions, srcCred *common.ScopedToken, dstCred *common.ScopedAuthenticator) azcore.ClientOptions { // Pipeline will look like // [includeResponsePolicy, newAPIVersionPolicy (ignored), NewTelemetryPolicy, perCall, NewRetryPolicy, perRetry, NewLogPolicy, httpHeaderPolicy, bodyDownloadPolicy] perCallPolicies := []policy.Policy{azruntime.NewRequestIDPolicy(), NewVersionPolicy(), newFileUploadRangeFromURLFixPolicy()} // TODO : Default logging policy is not equivalent to old one. tracing HTTP request perRetryPolicies := []policy.Policy{newRetryNotificationPolicy(), newLogPolicy(log), newStatsPolicy()} + if dstCred != nil { + perCallPolicies = append(perRetryPolicies, NewDestReauthPolicy(dstCred)) + } if srcCred != nil { perRetryPolicies = append(perRetryPolicies, NewSourceAuthPolicy(srcCred)) } diff --git a/ste/testJobPartTransferManager_test.go b/ste/testJobPartTransferManager_test.go index 4255a6444..0961a8643 100644 --- a/ste/testJobPartTransferManager_test.go +++ b/ste/testJobPartTransferManager_test.go @@ -292,7 +292,7 @@ func (t *testJobPartTransferManager) S2SSourceClientOptions() azcore.ClientOptio httpClient := NewAzcopyHTTPClient(4) - return NewClientOptions(retryOptions, telemetryOptions, httpClient, LogOptions{}, nil) + return NewClientOptions(retryOptions, telemetryOptions, httpClient, LogOptions{}, nil, nil) } func (t *testJobPartTransferManager) CredentialOpOptions() *common.CredentialOpOptions { diff --git a/ste/zt_destReauthPolicy_test.go b/ste/zt_destReauthPolicy_test.go new file mode 100644 index 000000000..c35b31275 --- /dev/null +++ b/ste/zt_destReauthPolicy_test.go @@ -0,0 +1,137 @@ +package ste + +import ( + "context" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + blobservice "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/service" + "github.com/Azure/azure-storage-azcopy/v10/common" + "github.com/stretchr/testify/assert" + "io" + "net/http" + "strings" + "testing" + "time" +) + +const ( + authRequiredResp = ` + + InvalidAuthenticationInfo + placeholder +` + accountPropsResp = ` + +` +) + +type ReauthTransporter struct { + RequireAuth bool +} + +func (r *ReauthTransporter) Do(req *http.Request) (*http.Response, error) { + if r.RequireAuth { // Format this as a retry blob error + h := http.Header{} + h.Add("WWW-Authenticate", "Bearer authorization_uri=https://login.microsoftonline.com/c1cacfe1-4dd7-4d62-b8c5-5b6cf62d10f9/oauth2/authorize resource_id=https://storage.azure.com") + return &http.Response{ + StatusCode: http.StatusUnauthorized, + Status: http.StatusText(http.StatusUnauthorized), + ContentLength: int64(len(authRequiredResp)), + Body: io.NopCloser(strings.NewReader(authRequiredResp)), + Request: req, + Header: h, + }, nil + } + + return &http.Response{ + StatusCode: http.StatusOK, + Status: http.StatusText(http.StatusOK), + ContentLength: int64(len(accountPropsResp)), + Body: io.NopCloser(strings.NewReader(accountPropsResp)), + Request: req, + }, nil +} + +type ReauthTestCred struct { + // ImmediateReauth fires off an azidentity.AuthenticationRequiredError in GetToken + ImmediateReauth bool + + ReauthCallback func(ctx context.Context) +} + +func (r *ReauthTestCred) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) { + if r.ImmediateReauth { + return azcore.AccessToken{}, &azidentity.AuthenticationRequiredError{} + } + + return azcore.AccessToken{Token: "foobar", ExpiresOn: time.Now().Add(time.Hour * 24)}, nil +} + +func (r *ReauthTestCred) Authenticate(ctx context.Context, opts *policy.TokenRequestOptions) (azidentity.AuthenticationRecord, error) { + if r.ReauthCallback != nil { + r.ReauthCallback(ctx) + } + + r.ImmediateReauth = false + + return azidentity.AuthenticationRecord{}, nil +} + +// This is not an end-to-end test. But it is an instantaneous validation of the logic. +func TestDestReauthPolicy(t *testing.T) { + rootctx := context.WithValue(context.Background(), destReauthDebugNoPrompt, true) + ctx, cancel := context.WithCancel(rootctx) + + reauthed := false + cred := &ReauthTestCred{ + ReauthCallback: func(ctx context.Context) { + reauthed = true + assert.Equal(t, ctx.Value(destReauthDebugExecuted), true, "Expected reauth to occur via the policy") + assert.Equal(t, ctx.Value(destReauthDebugCause), destReauthDebugCauseAuthenticationRequired, "Expected reauth to occur via the AuthenticationRequired mechanism") + cancel() + }, + } + + transport := &ReauthTransporter{} + + opts := NewClientOptions( + policy.RetryOptions{}, + policy.TelemetryOptions{}, + transport, + LogOptions{}, + nil, (*common.ScopedAuthenticator)(common.NewScopedCredential[common.AuthenticateToken](cred, common.ECredentialType.OAuthToken())), + ) + + c, err := blobservice.NewClient("https://foobar.blob.core.windows.net/", cred, &blobservice.ClientOptions{ClientOptions: opts}) + assert.NoError(t, err, "Failed to create service client") + + // Initially, fire off a request that will get slapped with an AuthenticationRequired. + cred.ImmediateReauth = true + _, err = c.GetProperties(ctx, nil) + assert.Equal(t, reauthed, true, "Expected reauthentication attempt in request") + + // Reset the context + ctx, cancel = context.WithCancel(rootctx) + ctx = context.WithValue(ctx, destReauthDebugNoPrompt, true) + reauthed = false + + // =========== InvalidAuthenticationInfo ============ + + // Set the cred & request to require a reauth on round trip, rather than up front, triggering the alternative activation method + cred.ImmediateReauth = false + transport.RequireAuth = true + + // reset callback + cred.ReauthCallback = func(ctx context.Context) { + reauthed = true + assert.Equal(t, ctx.Value(destReauthDebugExecuted), true, "Expected reauth to occur via the policy") + assert.Equal(t, ctx.Value(destReauthDebugCause), destReauthDebugCauseInvalidAuthenticationInfo, "Expected reauth to occur via the InvalidAuthenticationInfo mechanism") + cancel() + transport.RequireAuth = false + } + + // Initially, fire off a request that will get slapped with an InvalidAuthenticationMethod + _, err = c.GetProperties(ctx, nil) + assert.Equal(t, reauthed, true, "Expected reauthentication attempt in request") +}