Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow AzCopy to receive OAuth tokens via GRPC for internal integrations #2778

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ var azcopyScanningLogger common.ILoggerResetable
var azcopyCurrentJobID common.JobID
var azcopySkipVersionCheck bool
var retryStatusCodes string
var grpcServerPort string

type jobLoggerInfo struct {
jobID common.JobID
Expand Down Expand Up @@ -189,6 +190,12 @@ var rootCmd = &cobra.Command{
beginDetectNewVersion()
}

// Start the optional GRPC control endpoint when this build provides it.
if common.GrpcShim.Available() {
	// Return the error produced by SetupGrpc itself, scoped to this statement, rather than
	// whatever an outer `err` variable happens to hold at this point.
	if err := any(common.GrpcShim).(common.GrpcCtl).SetupGrpc(grpcServerPort, common.AzcopyCurrentJobLogger); err != nil {
		return err
	}
}

if debugSkipFiles != "" {
for _, v := range strings.Split(debugSkipFiles, ";") {
if strings.HasPrefix(v, "/") {
Expand Down Expand Up @@ -256,9 +263,17 @@ func init() {
rootCmd.PersistentFlags().BoolVar(&azcopyAwaitAllowOpenFiles, "await-open", false, "Used when debugging, to tell AzCopy to await `open` on stdin, after scanning but before opening the first file. Assists with testing cases around file modifications between scanning and usage")
rootCmd.PersistentFlags().StringVar(&debugSkipFiles, "debug-skip-files", "", "Used when debugging, to tell AzCopy to cancel the job midway. List of relative paths to skip in the STE.")

// special remote control flag, only available if the build enabled it.
if common.GrpcShim.Available() {
	rootCmd.PersistentFlags().StringVar(&grpcServerPort, "grpc-server-addr", "", "Used in specific scenarios; defaults to disabled. If set, listens on the requested port (e.g. 127.0.0.1:9879). Protocol spec is in grpcctl/internal.")

	// Hide the flag under the name it is actually registered with. MarkHidden on a name that
	// was never registered only returns an error, so the flag would otherwise stay visible in help output.
	_ = rootCmd.PersistentFlags().MarkHidden("grpc-server-addr")
}

// reserved for partner teams
_ = rootCmd.PersistentFlags().MarkHidden("cancel-from-stdin")

// currently for use in the ev2 extension
_ = rootCmd.PersistentFlags().MarkHidden("enable-grpc-server")

// special flags to be used in case of unexpected service errors.
rootCmd.PersistentFlags().StringVar(&retryStatusCodes, "retry-status-codes", "", "Comma-separated list of HTTP status codes to retry on. (default '408;429;500;502;503;504')")
_ = rootCmd.PersistentFlags().MarkHidden("retry-status-codes")
Expand Down
1 change: 1 addition & 0 deletions common/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ func (AutoLoginType) MSI() AutoLoginType { return AutoLoginType(2) }
func (AutoLoginType) AzCLI() AutoLoginType { return AutoLoginType(3) }
func (AutoLoginType) PsCred() AutoLoginType { return AutoLoginType(4) }
func (AutoLoginType) Workload() AutoLoginType { return AutoLoginType(5) }
func (AutoLoginType) GRPC() AutoLoginType { return AutoLoginType(254) } // Ev2 Extension/FRP integration only. Receives fresh OAuth tokens via GRPC.
func (AutoLoginType) TokenStore() AutoLoginType { return AutoLoginType(255) } // Storage Explorer internal integration only. Do not add this to ValidAutoLoginTypes.

func (d AutoLoginType) String() string {
Expand Down
55 changes: 55 additions & 0 deletions common/grpc_setup.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
//go:build grpc
// +build grpc

package common

import (
"github.com/Azure/azure-storage-azcopy/v10/grpcctl"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"golang.stackrox.io/grpc-http1/server"
"net/http"
)

// SetupGrpc starts the GRPC control endpoint on addr, if one was requested.
//
// An empty addr leaves the endpoint disabled. Otherwise grpcctl.JobLog is pointed at logger and
// an HTTP server for grpcctl.GlobalGRPCServer is started on a background goroutine. The server
// speaks HTTP/1.1 (through the grpc-web downgrading handler) and cleartext HTTP/2 (h2c).
//
// The returned error is always nil: listening happens asynchronously, and a failure to listen
// panics from the serving goroutine instead of being reported to the caller.
func (grpcCtlImpl) SetupGrpc(addr string, logger ILoggerResetable) error {
	if addr == "" {
		// No address requested; the endpoint stays off.
		return nil
	}

	// JobLog is a function, rather than just a reference, to avoid a dependency loop. It's gross, I know.
	grpcctl.JobLog = func(s string) {
		// The caller may hand us a nil logger (other users of the job logger nil-check it too);
		// drop the message instead of panicking on a nil interface.
		if logger == nil {
			return
		}
		logger.Log(LogInfo, s)
	}

	// Spin off the HTTP server
	go func() {
		// The downgrade handler will allow clients to request grpc-web support, removing trailers, etc.
		// for platforms like .NET Framework 4.7.2.
		downgrading := server.CreateDowngradingHandler(
			grpcctl.GlobalGRPCServer,
			http.NotFoundHandler(),      // No fallback handler is needed.
			server.PreferGRPCWeb(true)) // If grpc-web is requested, grpc-web we'll give.

		// HTTP/1 needs support, but we must also support HTTP/2 for "modern" clients,
		// hence the h2c wrapper around the downgrading handler.
		srv := &http.Server{
			Addr:    addr,
			Handler: h2c.NewHandler(downgrading, &http2.Server{}),
		}

		// Start listening. ListenAndServe only returns on failure.
		if err := srv.ListenAndServe(); err != nil {
			panic("grpcfailed: " + err.Error())
		}
	}()

	// Historically, this could return an error. It does not anymore.
	return nil
}

// SetupOAuthSubscription subscribes updateFunc to token updates published on grpcctl.GlobalServer.
// Each grpcctl.OAuthTokenUpdate is handed to updateFunc as a *OAuthTokenUpdate.
func (grpcCtlImpl) SetupOAuthSubscription(updateFunc func(token *OAuthTokenUpdate)) {
	relay := func(update *grpcctl.OAuthTokenUpdate) {
		// Plain pointer conversion: no copy is made, the callee sees the same struct.
		converted := (*OAuthTokenUpdate)(update)
		updateFunc(converted)
	}

	grpcctl.Subscribe(grpcctl.GlobalServer, relay)
}
28 changes: 28 additions & 0 deletions common/grpc_shim.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package common

import "time"

/*
grpc_shim.go implements a shim that allows for GRPC functionality to reasonably disappear, removing all package references to grpcctl, and grpc in general.
*/

// GrpcCtl lists the GRPC entry points the rest of AzCopy may call.
// Check GrpcShim.Available() before asserting GrpcShim to this interface.
type GrpcCtl interface {
	// SetupGrpc prepares the GRPC endpoint for the given address, logging through logger.
	SetupGrpc(addr string, logger ILoggerResetable) error
	// SetupOAuthSubscription registers update to be called with incoming OAuth token updates.
	SetupOAuthSubscription(update func(token *OAuthTokenUpdate))
}

// grpcCtlImpl is the stateless receiver type behind GrpcShim. This file gives it only the
// Available method; the GrpcCtl methods are attached in a separate, build-tagged file.
type grpcCtlImpl struct{}

// GrpcShim is the package-level handle used to reach GRPC functionality.
// Call Available before type-asserting it to GrpcCtl.
var GrpcShim grpcCtlImpl

// Available reports whether grpcCtlImpl implements GrpcCtl in this build, i.e. whether it is
// safe to type-assert GrpcShim to GrpcCtl.
func (g grpcCtlImpl) Available() bool {
	var shim any = g
	if _, implemented := shim.(GrpcCtl); implemented {
		return true
	}
	return false
}

// OAuthTokenUpdate is one OAuth token pushed to AzCopy over GRPC.
// Values of grpcctl.OAuthTokenUpdate are pointer-converted to this type, so the field list
// (names, types, order) must stay identical to that struct.
type OAuthTokenUpdate struct {
// Token is the access token string.
Token string
// Live is the start of the token's lifetime; presumably the issue/receipt time — confirm with the sender.
Live time.Time
// Expiry is the instant the token expires.
Expiry time.Time
// Wiggle is a safety margin subtracted from Expiry when deciding whether the token is still usable.
Wiggle time.Duration
}
142 changes: 139 additions & 3 deletions common/oauthTokenManager.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func newAzcopyHTTPClient() *http.Client {
Timeout: 10 * time.Second,
KeepAlive: 10 * time.Second,
DualStack: true,
}).Dial, /*Context*/
}).Dial, /*Context*/
MaxIdleConns: 0, // No limit
MaxIdleConnsPerHost: 1000,
IdleConnTimeout: 180 * time.Second,
Expand Down Expand Up @@ -398,6 +398,17 @@ func (uotm *UserOAuthTokenManager) getTokenInfoFromEnvVar(ctx context.Context) (
return nil, fmt.Errorf("get token from environment variable failed to unmarshal token, %v", err)
}

// Seed the grpc state with what we retrieved from the env var
if tokenInfo.LoginType == EAutoLoginType.GRPC() {
	g := globalRPCOAuthTokenState

	// g.Mutex.L is only the read side (RLocker) of globalGRPCOAuthTokenLock, so locking it would
	// let concurrent readers observe a half-written token. Mutate under the real write lock.
	globalGRPCOAuthTokenLock.Lock()
	g.Token = tokenInfo.AccessToken
	g.Live = time.Now()
	g.Expiry = tokenInfo.Expires()
	globalGRPCOAuthTokenLock.Unlock()

	// Wake anybody already waiting for a token, after the lock is released.
	g.Mutex.Broadcast()
}

if tokenInfo.LoginType != EAutoLoginType.TokenStore() {
refreshedToken, err := tokenInfo.Refresh(ctx)
if err != nil {
Expand Down Expand Up @@ -524,15 +535,14 @@ type TokenStoreCredential struct {
// we do not make repeated GetToken calls.
// This is a temporary fix for issue where we would request a
// new token from Stg Exp even while they've not yet populated the
// tokenstore.
// tokenstore.
//
// This is okay because we use same credential on both source and
// destination. If we move to a case where the credentials are
// different, this should be removed.
//
// We should move to a method where the token is always read from
// tokenstore, and azcopy is invoked after tokenstore is populated.
//
var globalTokenStoreCredential *TokenStoreCredential
var globalTsc sync.Once

Expand Down Expand Up @@ -739,6 +749,12 @@ func (credInfo *OAuthTokenInfo) GetDeviceCodeCredential() (azcore.TokenCredentia
return tc, nil
}

// GetGRPCOAuthCredential returns the process-wide GRPC-fed token state as the credential and
// caches it on credInfo.TokenCredential. It never fails; the error result is always nil.
func (credInfo *OAuthTokenInfo) GetGRPCOAuthCredential() (azcore.TokenCredential, error) {
	shared := globalRPCOAuthTokenState

	// Remember the credential on credInfo.
	credInfo.TokenCredential = shared

	return shared, nil
}

func (credInfo *OAuthTokenInfo) GetTokenCredential() (azcore.TokenCredential, error) {
// Token Credential is cached.
if credInfo.TokenCredential != nil {
Expand All @@ -750,6 +766,8 @@ func (credInfo *OAuthTokenInfo) GetTokenCredential() (azcore.TokenCredential, er
}

switch credInfo.LoginType {
case EAutoLoginType.GRPC():
return credInfo.GetGRPCOAuthCredential()
case EAutoLoginType.MSI():
return credInfo.GetManagedIdentityCredential()
case EAutoLoginType.SPN():
Expand Down Expand Up @@ -785,6 +803,124 @@ func jsonToTokenInfo(b []byte) (*OAuthTokenInfo, error) {

// ====================================================================================

// globalGRPCOAuthTokenLock guards the fields of globalRPCOAuthTokenState.
// Code that modifies the state must take the write side directly via Lock/Unlock on this mutex.
var globalGRPCOAuthTokenLock = &sync.RWMutex{}

// Pass in the RLocker to the sync.Cond so we're only requesting read locks, not write locks.
// Consequently globalRPCOAuthTokenState.Mutex.L.Lock() (and the re-lock inside Mutex.Wait) only
// ever acquires a READ lock on globalGRPCOAuthTokenLock; it is not sufficient for writes.
var globalRPCOAuthTokenState = &GRPCOAuthToken{Mutex: sync.NewCond(globalGRPCOAuthTokenLock.RLocker())}

func init() {
vibhansa-msft marked this conversation as resolved.
Show resolved Hide resolved
if GrpcShim.Available() {
any(GrpcShim).(GrpcCtl).SetupOAuthSubscription(func(token *OAuthTokenUpdate) {
g := globalRPCOAuthTokenState
globalGRPCOAuthTokenLock.Lock() // Grab the write lock

if AzcopyCurrentJobLogger != nil {
AzcopyCurrentJobLogger.Log(LogInfo, fmt.Sprintf("Received fresh OAuth token. (invalid: %v, exp: %v, now: %v, expiry: %v)", g.Token == "", time.Now().After(g.Expiry), time.Now(), g.Expiry))
}

// Write the fresh token we've been handed
g.Token = token.Token
g.Live = token.Live
g.Expiry = token.Expiry
g.Wiggle = token.Wiggle
g.GiveUp = false // We've stopped giving up, if we've received a fresh token.

if AzcopyCurrentJobLogger != nil {
AzcopyCurrentJobLogger.Log(LogInfo, "Broadcasting new OAuth token.")
}

// Drop the lock, let "clients" know there's a new token.
globalGRPCOAuthTokenLock.Unlock()
g.Mutex.Broadcast()
})
}
}

// GRPCOAuthToken holds the most recent OAuth token handed to AzCopy from outside, together with
// the condition variable that consumers wait on until a newer token arrives.
// All fields other than Mutex are shared state and must be accessed under the lock behind Mutex.
type GRPCOAuthToken struct {
// Token is the current access token; empty means no usable token is held.
Token string
// Live marks the start of the token's lifetime; Expiry minus Live is the token's total duration.
Live time.Time
// Expiry is the instant the token expires.
Expiry time.Time
// Wiggle is how long before Expiry we'll already act like the token is expired.
// It is a time.Duration, not a bare number of seconds.
Wiggle time.Duration

// If we don't receive an oauth token for an extended period of time, just give up and fail fast. It's too long to wait for say, 50k files to fail at their own pace.
GiveUp bool

// Using a sync.Cond, we allow "clients" to drop their lock and await a fresh token signal.
// Despite its name this is a condition variable, not a mutex; its L field is the associated lock.
Mutex *sync.Cond
}

// GetToken returns the token currently held in g, blocking until a usable one is available.
//
// A token is usable when it is non-empty, its lifetime (Expiry - Live) is not negative and the
// current time is not past Expiry - Wiggle. While the held token is unusable the call waits on
// g.Mutex for a broadcast. If no usable token shows up within three times the lifetime of the
// last token, GiveUp is set for everybody and a timeout error is returned. A non-positive
// lifetime therefore gives up straight away. Once GiveUp is set, every call fails immediately
// until a fresh token clears it.
//
// ctx and options are not consulted: cancelling ctx does not interrupt the wait.
//
// Locking: g.Mutex.L is the read side of globalGRPCOAuthTokenLock, so this method only ever
// holds a read lock itself. The one write it needs (GiveUp) is done by the timeout callback
// under globalGRPCOAuthTokenLock's real write lock.
func (g *GRPCOAuthToken) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) {
	g.Mutex.L.Lock()         // Grab the read lock
	defer g.Mutex.L.Unlock() // Wait() always hands the read lock back, so it is held on every return path.

	// unusable inspects the current state; callers must hold the read or the write lock.
	unusable := func() (bool, bool) {
		invalid := g.Token == "" || g.Expiry.Sub(g.Live) < 0
		expired := time.Now().After(g.Expiry.Add(-g.Wiggle))
		return invalid, expired
	}

	for {
		// Escape hatch: new requests after we've given up also give up immediately. This reduces failure time.
		if g.GiveUp {
			return azcore.AccessToken{}, fmt.Errorf("timed out waiting for new token (GiveUp set)")
		}

		invalid, expired := unusable()
		if !invalid && !expired {
			return azcore.AccessToken{
				Token: g.Token,
				// Expiry could be zero and the result would be the same: whenever we are asked
				// again we hand back whatever token is current at that point.
				ExpiresOn: g.Expiry,
			}, nil
		}

		// Log any potential issues with the token.
		if AzcopyCurrentJobLogger != nil {
			AzcopyCurrentJobLogger.Log(LogInfo, fmt.Sprintf("Token is expired or invalid (invalid: %v, exp: %v, now: %v, expiry: %v)", invalid, expired, time.Now(), g.Expiry.Add(-g.Wiggle)))
		}

		// Begin waiting for a fresh token, but time out eventually so that AzCopy can exit.
		waitBegin := time.Now()
		timeout := g.Expiry.Sub(g.Live) * 3

		// pending and timedOut belong to this wait only. We touch them while holding the read
		// lock and the callback touches them while holding the write lock, so the accesses
		// never overlap.
		pending, timedOut := true, false
		timer := time.AfterFunc(timeout, func() {
			// The write lock cannot be obtained before Wait below has registered us and
			// released our read lock, so the Broadcast that follows cannot be lost.
			globalGRPCOAuthTokenLock.Lock()
			gaveUp := false
			if pending {
				// Only give up if nothing usable arrived in the meantime.
				if stillInvalid, stillExpired := unusable(); stillInvalid || stillExpired {
					g.GiveUp = true // Tell everybody we're giving up.
					timedOut = true
					gaveUp = true
				}
			}
			globalGRPCOAuthTokenLock.Unlock()

			if gaveUp {
				// Unblock our waiter, *and* everybody else, now that we know we're giving up.
				g.Mutex.Broadcast()
			}
		})

		// Wait releases the read lock, parks until a Broadcast, then re-acquires the read lock.
		g.Mutex.Wait()
		pending = false // A late-firing callback must leave this finished wait alone.
		timer.Stop()

		if AzcopyCurrentJobLogger != nil {
			AzcopyCurrentJobLogger.Log(LogInfo, "Released.")
		}

		// If a fresh token cleared GiveUp between the timeout and our wake-up, fall through and
		// re-evaluate instead of failing.
		if timedOut && g.GiveUp {
			return azcore.AccessToken{}, fmt.Errorf("timed out waiting for new token (3x last live duration) (Began waiting %v, finished %v, duration %v)", waitBegin, time.Now(), timeout)
		}
	}
}

// ====================================================================================

// TestOAuthInjection controls variables for OAuth testing injections
type TestOAuthInjection struct {
DoTokenRefreshInjection bool
Expand Down
7 changes: 5 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ require (
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0
github.com/Azure/go-autorest/autorest/date v0.3.0
golang.org/x/net v0.27.0
golang.stackrox.io/grpc-http1 v0.3.12
google.golang.org/grpc v1.65.0
google.golang.org/protobuf v1.34.2
)

require (
Expand All @@ -55,6 +58,7 @@ require (
github.com/go-logr/stdr v1.2.2 // indirect
github.com/golang-jwt/jwt/v4 v4.5.0 // indirect
github.com/golang-jwt/jwt/v5 v5.2.1 // indirect
github.com/golang/glog v1.2.1 // indirect
github.com/google/s2a-go v0.1.8 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.13.0 // indirect
Expand All @@ -78,9 +82,8 @@ require (
google.golang.org/genproto v0.0.0-20240723171418-e6d459c13d2a // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240723171418-e6d459c13d2a // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240723171418-e6d459c13d2a // indirect
google.golang.org/grpc v1.65.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
nhooyr.io/websocket v1.8.11 // indirect
)

go 1.22.5
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/glog v1.2.1 h1:OptwRhECazUx5ix5TTWC3EZhsZEHWcYWY4FQHTIubm4=
github.com/golang/glog v1.2.1/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
Expand Down Expand Up @@ -256,6 +258,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.stackrox.io/grpc-http1 v0.3.12 h1:5Hmpu6OJN4GsoqF+orSv5rPd/R/W02rfPkX0xq3GY1I=
golang.stackrox.io/grpc-http1 v0.3.12/go.mod h1:Xl6wDiPJ7OjLWOzwxFwhThktLTNs1iDS95alVpZct/A=
google.golang.org/api v0.189.0 h1:equMo30LypAkdkLMBqfeIqtyAnlyig1JSZArl4XPwdI=
google.golang.org/api v0.189.0/go.mod h1:FLWGJKb0hb+pU2j+rJqwbnsF+ym+fQs73rbJ+KAUgy8=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
Expand Down Expand Up @@ -295,3 +299,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
nhooyr.io/websocket v1.8.11 h1:f/qXNc2/3DpoSZkHt1DQu6rj4zGC8JmkkLkWss0MgN0=
nhooyr.io/websocket v1.8.11/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c=
Loading
Loading