From 8ff3cc74b7fd8909bde314dfdd772a7a318bd5c6 Mon Sep 17 00:00:00 2001 From: Bani Singh <47721811+banikharbanda@users.noreply.github.com> Date: Wed, 20 Nov 2024 10:08:05 -0500 Subject: [PATCH] Separate PerRPCCreds from Oauth creds to eliminate undefined behavior (#603) --- go/pkg/balancer/BUILD.bazel | 7 +- go/pkg/credshelper/BUILD.bazel | 20 ++---- go/pkg/credshelper/credshelper.go | 89 ++++++++++++++++++-------- go/pkg/credshelper/credshelper_test.go | 13 ++-- go/pkg/flags/flags.go | 2 +- 5 files changed, 74 insertions(+), 57 deletions(-) diff --git a/go/pkg/balancer/BUILD.bazel b/go/pkg/balancer/BUILD.bazel index f43fdf804..a0c4c673c 100644 --- a/go/pkg/balancer/BUILD.bazel +++ b/go/pkg/balancer/BUILD.bazel @@ -1,14 +1,11 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("@io_bazel_rules_go//go:def.bzl", "go_library") go_library( name = "balancer", - srcs = [ - "roundrobin.go", - ], + srcs = ["roundrobin.go"], importpath = "github.com/bazelbuild/remote-apis-sdks/go/pkg/balancer", visibility = ["//visibility:public"], deps = [ "@org_golang_google_grpc//:go_default_library", ], ) - diff --git a/go/pkg/credshelper/BUILD.bazel b/go/pkg/credshelper/BUILD.bazel index 23d5d7263..c9c8cc8e1 100644 --- a/go/pkg/credshelper/BUILD.bazel +++ b/go/pkg/credshelper/BUILD.bazel @@ -2,32 +2,20 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "credshelper", - srcs = [ - "credshelper.go", - ], + srcs = ["credshelper.go"], importpath = "github.com/bazelbuild/remote-apis-sdks/go/pkg/credshelper", visibility = ["//visibility:public"], deps = [ - "//go/api/credshelper", - "//go/pkg/digest", + "//go/pkg/digest", "@com_github_golang_glog//:glog", - "@com_github_hectane_go_acl//:go-acl", + "@org_golang_google_grpc//credentials", "@org_golang_google_grpc//credentials/oauth", - "@org_golang_google_protobuf//encoding/prototext", - "@org_golang_google_protobuf//types/known/timestamppb", "@org_golang_x_oauth2//:oauth2", - "@org_golang_x_oauth2//google", ], ) go_test( name = "credshelper_test", - srcs = [ - "credshelper_test.go", - ], + srcs = ["credshelper_test.go"], embed = [":credshelper"], - deps = [ - "@com_github_google_go_cmp//cmp", - "@org_golang_x_oauth2//:oauth2", - ], ) diff --git a/go/pkg/credshelper/credshelper.go b/go/pkg/credshelper/credshelper.go index 9289da6e9..ab2a32c62 100644 --- a/go/pkg/credshelper/credshelper.go +++ b/go/pkg/credshelper/credshelper.go @@ -18,6 +18,7 @@ import ( log "github.com/golang/glog" "golang.org/x/oauth2" + "google.golang.org/grpc/credentials" grpcOauth "google.golang.org/grpc/credentials/oauth" ) @@ -77,6 +78,16 @@ func (r *reusableCmd) Digest() digest.Digest { // Credentials provides auth functionalities using an external credentials helper type Credentials struct { tokenSource *grpcOauth.TokenSource + perRPCCreds *perRPCCredentials + credsHelperCmd *reusableCmd +} + +// perRPCCredentials fullfills the grpc.Credentials.PerRPCCredentials interface +// to provde auth functionalities with headers +type perRPCCredentials struct { + headers map[string]string + expiry time.Time + headersLock sync.RWMutex credsHelperCmd *reusableCmd } @@ -86,9 +97,6 @@ type Credentials struct { // oauth2.TokenSource and credentials.PerRPCCredentials interfaces. type externalTokenSource struct { credsHelperCmd *reusableCmd - headers map[string]string - expiry time.Time - headersLock sync.RWMutex } // TokenSource returns a token source for this credentials instance. @@ -99,6 +107,20 @@ func (c *Credentials) TokenSource() *grpcOauth.TokenSource { return c.tokenSource } +// PerRPCCreds returns a perRPCCredentials for this credentials instance. +func (c *Credentials) PerRPCCreds() credentials.PerRPCCredentials { + if c == nil { + return nil + } + // If no perRPCCreds exist for this Credentials object, then + // grpcOauth.TokenSource will do since it implements the same interface + // and some credentials helpers may only provide a token without headers + if c.perRPCCreds == nil { + return c.TokenSource() + } + return c.perRPCCreds +} + // Token retrieves an oauth2 token from the external tokensource. func (ts *externalTokenSource) Token() (*oauth2.Token, error) { if ts == nil { @@ -108,27 +130,30 @@ func (ts *externalTokenSource) Token() (*oauth2.Token, error) { if err != nil { return nil, err } + if credsOut.tk.AccessToken == "" { + return nil, fmt.Errorf("no token was printed by the credentials helper") + } log.Infof("'%s' credentials refreshed at %v, expires at %v", ts.credsHelperCmd, time.Now(), credsOut.tk.Expiry) return credsOut.tk, err } // GetRequestMetadata gets the current request metadata, refreshing tokens if required. -func (ts *externalTokenSource) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { - ts.headersLock.RLock() - defer ts.headersLock.RUnlock() - if ts.expiry.Before(nowFn().Add(-expiryBuffer)) { - credsOut, err := runCredsHelperCmd(ts.credsHelperCmd) +func (p *perRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + p.headersLock.RLock() + defer p.headersLock.RUnlock() + if p.expiry.Before(nowFn().Add(-expiryBuffer)) { + credsOut, err := runCredsHelperCmd(p.credsHelperCmd) if err != nil { return nil, err } - ts.expiry = credsOut.tk.Expiry - ts.headers = credsOut.hdrs + p.expiry = credsOut.tk.Expiry + p.headers = credsOut.hdrs } - return ts.headers, nil + return p.headers, nil } // RequireTransportSecurity indicates whether the credentials require transport security. -func (ts *externalTokenSource) RequireTransportSecurity() bool { +func (p *perRPCCredentials) RequireTransportSecurity() bool { return true } @@ -149,20 +174,29 @@ func NewExternalCredentials(credshelper string, credshelperArgs []string) (*Cred c := &Credentials{ credsHelperCmd: credsHelperCmd, } - baseTS := &externalTokenSource{ - credsHelperCmd: credsHelperCmd, + if len(credsOut.hdrs) != 0 { + c.perRPCCreds = &perRPCCredentials{ + headers: credsOut.hdrs, + expiry: credsOut.tk.Expiry, + credsHelperCmd: credsHelperCmd, + } } - c.tokenSource = &grpcOauth.TokenSource{ - // Wrap the base token source with a ReuseTokenSource so that we only - // generate new credentials when the current one is about to expire. - // This is needed because retrieving the token is expensive and some - // token providers have per hour rate limits. - TokenSource: oauth2.ReuseTokenSourceWithExpiry( - credsOut.tk, - baseTS, - // Refresh tokens a bit early to be safe - expiryBuffer, - ), + if credsOut.tk.AccessToken != "" { + baseTS := &externalTokenSource{ + credsHelperCmd: credsHelperCmd, + } + c.tokenSource = &grpcOauth.TokenSource{ + // Wrap the base token source with a ReuseTokenSource so that we only + // generate new credentials when the current one is about to expire. + // This is needed because retrieving the token is expensive and some + // token providers have per hour rate limits. + TokenSource: oauth2.ReuseTokenSourceWithExpiry( + credsOut.tk, + baseTS, + // Refresh tokens a bit early to be safe + expiryBuffer, + ), + } } return c, nil } @@ -170,7 +204,6 @@ func NewExternalCredentials(credshelper string, credshelperArgs []string) (*Cred type credshelperOutput struct { hdrs map[string]string tk *oauth2.Token - rexp time.Time } func runCredsHelperCmd(credsHelperCmd *reusableCmd) (*credshelperOutput, error) { @@ -203,8 +236,8 @@ func parseTokenExpiryFromOutput(out string) (*credshelperOutput, error) { if err := json.Unmarshal([]byte(out), &jsonOut); err != nil { return nil, fmt.Errorf("error while decoding credshelper output:%v", err) } - if jsonOut.Token == "" { - return nil, fmt.Errorf("no token was printed by the credentials helper") + if jsonOut.Token == "" && len(jsonOut.Headers) == 0 { + return nil, fmt.Errorf("both token and headers are empty, invalid credentials") } credsOut.tk = &oauth2.Token{AccessToken: jsonOut.Token} credsOut.hdrs = jsonOut.Headers diff --git a/go/pkg/credshelper/credshelper_test.go b/go/pkg/credshelper/credshelper_test.go index 4621652d6..c75a6f3c8 100644 --- a/go/pkg/credshelper/credshelper_test.go +++ b/go/pkg/credshelper/credshelper_test.go @@ -125,7 +125,6 @@ func TestNewExternalCredentials(t *testing.T) { credshelperOut: fmt.Sprintf(`{"token":"%v","expiry":""}`, testToken), }, { name: "No Token", - wantErr: true, credshelperOut: `{"headers":{"hdr":"val"},"token":"","expiry":""}`, }, { name: "Credshelper Command Passed - No Expiry", @@ -168,7 +167,7 @@ func TestNewExternalCredentials(t *testing.T) { if test.wantErr && err == nil { t.Fatalf("NewExternalCredentials did not return an error.") } - if !test.wantErr { + if !test.wantErr && test.name != "No Token" { if err != nil { t.Fatalf("NewExternalCredentials returned an error: %v", err) } @@ -247,13 +246,13 @@ func TestGetRequestMetadata(t *testing.T) { credshelperArgs = []string{test.credshelperOut} } credsHelperCmd := newReusableCmd(credshelper, credshelperArgs) - exTs := externalTokenSource{ + p := perRPCCredentials{ credsHelperCmd: credsHelperCmd, expiry: test.tsExp, headers: test.tsHeaders, headersLock: sync.RWMutex{}, } - hdrs, err := exTs.GetRequestMetadata(context.Background(), "uri") + hdrs, err := p.GetRequestMetadata(context.Background(), "uri") if test.wantErr && err == nil { t.Fatalf("GetRequestMetadata did not return an error.") } @@ -261,10 +260,10 @@ func TestGetRequestMetadata(t *testing.T) { if err != nil { t.Fatalf("GetRequestMetadata returned an error: %v", err) } - if !reflect.DeepEqual(hdrs, exTs.headers) { - t.Errorf("GetRequestMetadata did not update headers in the tokensource: returned hdrs: %v, tokensource headers: %v", hdrs, exTs.headers) + if !reflect.DeepEqual(hdrs, p.headers) { + t.Errorf("GetRequestMetadata did not update headers in the tokensource: returned hdrs: %v, tokensource headers: %v", hdrs, p.headers) } - if !exp.Equal(exTs.expiry) { + if !exp.Equal(p.expiry) { t.Errorf("GetRequestMetadata did not update expiry in the tokensource") } if !test.wantExpired && !reflect.DeepEqual(hdrs, testHdrs) { diff --git a/go/pkg/flags/flags.go b/go/pkg/flags/flags.go index ed6e74258..b10f140aa 100644 --- a/go/pkg/flags/flags.go +++ b/go/pkg/flags/flags.go @@ -125,7 +125,7 @@ func NewClientFromFlags(ctx context.Context, opts ...client.Opt) (*client.Client if err != nil { return nil, fmt.Errorf("credentials helper failed. Please try again or use another method of authentication:%v", err) } - perRPCCreds = &client.PerRPCCreds{Creds: creds.TokenSource()} + perRPCCreds = &client.PerRPCCreds{Creds: creds.PerRPCCreds()} } opts = tOpts