Skip to content

Commit

Permalink
Separate PerRPCCreds from Oauth creds to eliminate undefined behavior (
Browse files Browse the repository at this point in the history
  • Loading branch information
banikharbanda authored Nov 20, 2024
1 parent 3d9543e commit 8ff3cc7
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 57 deletions.
7 changes: 2 additions & 5 deletions go/pkg/balancer/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
load("@io_bazel_rules_go//go:def.bzl", "go_library")

go_library(
name = "balancer",
srcs = [
"roundrobin.go",
],
srcs = ["roundrobin.go"],
importpath = "github.com/bazelbuild/remote-apis-sdks/go/pkg/balancer",
visibility = ["//visibility:public"],
deps = [
"@org_golang_google_grpc//:go_default_library",
],
)

20 changes: 4 additions & 16 deletions go/pkg/credshelper/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,20 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

go_library(
name = "credshelper",
srcs = [
"credshelper.go",
],
srcs = ["credshelper.go"],
importpath = "github.com/bazelbuild/remote-apis-sdks/go/pkg/credshelper",
visibility = ["//visibility:public"],
deps = [
"//go/api/credshelper",
"//go/pkg/digest",
"//go/pkg/digest",
"@com_github_golang_glog//:glog",
"@com_github_hectane_go_acl//:go-acl",
"@org_golang_google_grpc//credentials",
"@org_golang_google_grpc//credentials/oauth",
"@org_golang_google_protobuf//encoding/prototext",
"@org_golang_google_protobuf//types/known/timestamppb",
"@org_golang_x_oauth2//:oauth2",
"@org_golang_x_oauth2//google",
],
)

go_test(
name = "credshelper_test",
srcs = [
"credshelper_test.go",
],
srcs = ["credshelper_test.go"],
embed = [":credshelper"],
deps = [
"@com_github_google_go_cmp//cmp",
"@org_golang_x_oauth2//:oauth2",
],
)
89 changes: 61 additions & 28 deletions go/pkg/credshelper/credshelper.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (

log "github.com/golang/glog"
"golang.org/x/oauth2"
"google.golang.org/grpc/credentials"
grpcOauth "google.golang.org/grpc/credentials/oauth"
)

Expand Down Expand Up @@ -77,6 +78,16 @@ func (r *reusableCmd) Digest() digest.Digest {
// Credentials provides auth functionalities using an external credentials helper
type Credentials struct {
tokenSource *grpcOauth.TokenSource
perRPCCreds *perRPCCredentials
credsHelperCmd *reusableCmd
}

// perRPCCredentials fullfills the grpc.Credentials.PerRPCCredentials interface
// to provde auth functionalities with headers
type perRPCCredentials struct {
headers map[string]string
expiry time.Time
headersLock sync.RWMutex
credsHelperCmd *reusableCmd
}

Expand All @@ -86,9 +97,6 @@ type Credentials struct {
// oauth2.TokenSource and credentials.PerRPCCredentials interfaces.
type externalTokenSource struct {
credsHelperCmd *reusableCmd
headers map[string]string
expiry time.Time
headersLock sync.RWMutex
}

// TokenSource returns a token source for this credentials instance.
Expand All @@ -99,6 +107,20 @@ func (c *Credentials) TokenSource() *grpcOauth.TokenSource {
return c.tokenSource
}

// PerRPCCreds returns a perRPCCredentials for this credentials instance.
func (c *Credentials) PerRPCCreds() credentials.PerRPCCredentials {
if c == nil {
return nil
}
// If no perRPCCreds exist for this Credentials object, then
// grpcOauth.TokenSource will do since it implements the same interface
// and some credentials helpers may only provide a token without headers
if c.perRPCCreds == nil {
return c.TokenSource()
}
return c.perRPCCreds
}

// Token retrieves an oauth2 token from the external tokensource.
func (ts *externalTokenSource) Token() (*oauth2.Token, error) {
if ts == nil {
Expand All @@ -108,27 +130,30 @@ func (ts *externalTokenSource) Token() (*oauth2.Token, error) {
if err != nil {
return nil, err
}
if credsOut.tk.AccessToken == "" {
return nil, fmt.Errorf("no token was printed by the credentials helper")
}
log.Infof("'%s' credentials refreshed at %v, expires at %v", ts.credsHelperCmd, time.Now(), credsOut.tk.Expiry)
return credsOut.tk, err
}

// GetRequestMetadata gets the current request metadata, refreshing tokens if required.
func (ts *externalTokenSource) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
ts.headersLock.RLock()
defer ts.headersLock.RUnlock()
if ts.expiry.Before(nowFn().Add(-expiryBuffer)) {
credsOut, err := runCredsHelperCmd(ts.credsHelperCmd)
func (p *perRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
p.headersLock.RLock()
defer p.headersLock.RUnlock()
if p.expiry.Before(nowFn().Add(-expiryBuffer)) {
credsOut, err := runCredsHelperCmd(p.credsHelperCmd)
if err != nil {
return nil, err
}
ts.expiry = credsOut.tk.Expiry
ts.headers = credsOut.hdrs
p.expiry = credsOut.tk.Expiry
p.headers = credsOut.hdrs
}
return ts.headers, nil
return p.headers, nil
}

// RequireTransportSecurity indicates whether the credentials require transport security.
func (ts *externalTokenSource) RequireTransportSecurity() bool {
func (p *perRPCCredentials) RequireTransportSecurity() bool {
return true
}

Expand All @@ -149,28 +174,36 @@ func NewExternalCredentials(credshelper string, credshelperArgs []string) (*Cred
c := &Credentials{
credsHelperCmd: credsHelperCmd,
}
baseTS := &externalTokenSource{
credsHelperCmd: credsHelperCmd,
if len(credsOut.hdrs) != 0 {
c.perRPCCreds = &perRPCCredentials{
headers: credsOut.hdrs,
expiry: credsOut.tk.Expiry,
credsHelperCmd: credsHelperCmd,
}
}
c.tokenSource = &grpcOauth.TokenSource{
// Wrap the base token source with a ReuseTokenSource so that we only
// generate new credentials when the current one is about to expire.
// This is needed because retrieving the token is expensive and some
// token providers have per hour rate limits.
TokenSource: oauth2.ReuseTokenSourceWithExpiry(
credsOut.tk,
baseTS,
// Refresh tokens a bit early to be safe
expiryBuffer,
),
if credsOut.tk.AccessToken != "" {
baseTS := &externalTokenSource{
credsHelperCmd: credsHelperCmd,
}
c.tokenSource = &grpcOauth.TokenSource{
// Wrap the base token source with a ReuseTokenSource so that we only
// generate new credentials when the current one is about to expire.
// This is needed because retrieving the token is expensive and some
// token providers have per hour rate limits.
TokenSource: oauth2.ReuseTokenSourceWithExpiry(
credsOut.tk,
baseTS,
// Refresh tokens a bit early to be safe
expiryBuffer,
),
}
}
return c, nil
}

type credshelperOutput struct {
hdrs map[string]string
tk *oauth2.Token
rexp time.Time
}

func runCredsHelperCmd(credsHelperCmd *reusableCmd) (*credshelperOutput, error) {
Expand Down Expand Up @@ -203,8 +236,8 @@ func parseTokenExpiryFromOutput(out string) (*credshelperOutput, error) {
if err := json.Unmarshal([]byte(out), &jsonOut); err != nil {
return nil, fmt.Errorf("error while decoding credshelper output:%v", err)
}
if jsonOut.Token == "" {
return nil, fmt.Errorf("no token was printed by the credentials helper")
if jsonOut.Token == "" && len(jsonOut.Headers) == 0 {
return nil, fmt.Errorf("both token and headers are empty, invalid credentials")
}
credsOut.tk = &oauth2.Token{AccessToken: jsonOut.Token}
credsOut.hdrs = jsonOut.Headers
Expand Down
13 changes: 6 additions & 7 deletions go/pkg/credshelper/credshelper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ func TestNewExternalCredentials(t *testing.T) {
credshelperOut: fmt.Sprintf(`{"token":"%v","expiry":""}`, testToken),
}, {
name: "No Token",
wantErr: true,
credshelperOut: `{"headers":{"hdr":"val"},"token":"","expiry":""}`,
}, {
name: "Credshelper Command Passed - No Expiry",
Expand Down Expand Up @@ -168,7 +167,7 @@ func TestNewExternalCredentials(t *testing.T) {
if test.wantErr && err == nil {
t.Fatalf("NewExternalCredentials did not return an error.")
}
if !test.wantErr {
if !test.wantErr && test.name != "No Token" {
if err != nil {
t.Fatalf("NewExternalCredentials returned an error: %v", err)
}
Expand Down Expand Up @@ -247,24 +246,24 @@ func TestGetRequestMetadata(t *testing.T) {
credshelperArgs = []string{test.credshelperOut}
}
credsHelperCmd := newReusableCmd(credshelper, credshelperArgs)
exTs := externalTokenSource{
p := perRPCCredentials{
credsHelperCmd: credsHelperCmd,
expiry: test.tsExp,
headers: test.tsHeaders,
headersLock: sync.RWMutex{},
}
hdrs, err := exTs.GetRequestMetadata(context.Background(), "uri")
hdrs, err := p.GetRequestMetadata(context.Background(), "uri")
if test.wantErr && err == nil {
t.Fatalf("GetRequestMetadata did not return an error.")
}
if !test.wantErr {
if err != nil {
t.Fatalf("GetRequestMetadata returned an error: %v", err)
}
if !reflect.DeepEqual(hdrs, exTs.headers) {
t.Errorf("GetRequestMetadata did not update headers in the tokensource: returned hdrs: %v, tokensource headers: %v", hdrs, exTs.headers)
if !reflect.DeepEqual(hdrs, p.headers) {
t.Errorf("GetRequestMetadata did not update headers in the tokensource: returned hdrs: %v, tokensource headers: %v", hdrs, p.headers)
}
if !exp.Equal(exTs.expiry) {
if !exp.Equal(p.expiry) {
t.Errorf("GetRequestMetadata did not update expiry in the tokensource")
}
if !test.wantExpired && !reflect.DeepEqual(hdrs, testHdrs) {
Expand Down
2 changes: 1 addition & 1 deletion go/pkg/flags/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func NewClientFromFlags(ctx context.Context, opts ...client.Opt) (*client.Client
if err != nil {
return nil, fmt.Errorf("credentials helper failed. Please try again or use another method of authentication:%v", err)
}
perRPCCreds = &client.PerRPCCreds{Creds: creds.TokenSource()}
perRPCCreds = &client.PerRPCCreds{Creds: creds.PerRPCCreds()}
}
opts = tOpts

Expand Down

0 comments on commit 8ff3cc7

Please sign in to comment.