Skip to content

Commit

Permalink
test: login path (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
tenstad authored Sep 26, 2023
1 parent e3030be commit c96bb40
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 54 deletions.
33 changes: 27 additions & 6 deletions internal/jwtauth/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/hashicorp/vault-client-go"
"github.com/hashicorp/vault-client-go/schema"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"

Expand Down Expand Up @@ -34,7 +35,11 @@ type multiroleJWTAuthBackend struct {
l sync.RWMutex
cachedConfig *multiroleJWTConfig
roleIndex *roleIndex
client *vault.Client
policyClient policyFetcher
}

type policyFetcher interface {
policies(ctx context.Context, request schema.JwtLoginRequest) ([]string, error)
}

func backend(_ *logical.BackendConfig) *multiroleJWTAuthBackend {
Expand Down Expand Up @@ -78,12 +83,12 @@ func (b *multiroleJWTAuthBackend) getRoleIndex(config *multiroleJWTConfig) (*rol
return index, nil
}

func (b *multiroleJWTAuthBackend) getClient(config *multiroleJWTConfig) (*vault.Client, error) {
func (b *multiroleJWTAuthBackend) getPolicyClient(config *multiroleJWTConfig) (policyFetcher, error) {
b.l.Lock()
defer b.l.Unlock()

if b.client != nil {
return b.client, nil
if b.policyClient != nil {
return b.policyClient, nil
}

client, err := vault.New(
Expand All @@ -94,6 +99,22 @@ func (b *multiroleJWTAuthBackend) getClient(config *multiroleJWTConfig) (*vault.
return nil, fmt.Errorf("failed to create vault client: %w", err)
}

b.client = client
return client, nil
b.policyClient = &vaultClient{
Client: client,
mountPath: config.JWTAuthPath,
}
return b.policyClient, nil
}

type vaultClient struct {
*vault.Client
mountPath string
}

func (c *vaultClient) policies(ctx context.Context, request schema.JwtLoginRequest) ([]string, error) {
r, err := c.Client.Auth.JwtLogin(ctx, request, vault.WithMountPath(c.mountPath))
if err != nil {
return nil, fmt.Errorf("vault error: %w", err)
}
return r.Auth.Policies, nil
}
54 changes: 54 additions & 0 deletions internal/jwtauth/backend_utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package jwtauth

import (
"context"
"testing"
"time"

log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/logical"
)

func createTestBackend(t *testing.T) (*multiroleJWTAuthBackend, logical.Storage) {
config := &logical.BackendConfig{
Logger: logging.NewVaultLogger(log.Trace),

System: &logical.StaticSystemView{
DefaultLeaseTTLVal: time.Hour * 12,
MaxLeaseTTLVal: time.Hour * 24,
},
StorageView: &logical.InmemStorage{},
}

logicalBackend, err := Factory(context.Background(), config)
if err != nil {
t.Fatalf("unable to create backend: %v", err)
}

backend, ok := logicalBackend.(*multiroleJWTAuthBackend)
if !ok {
t.Fatal("backend is not a multiroleJWTAuthBackend")
}

return backend, config.StorageView
}

func testConfig() map[string]any {
return map[string]any{
"roles": map[string]any{
"foo": map[string]any{
"project_path": []any{"foo", "bar"},
},
"bar": map[string]any{
"namespace_path": []any{"c"},
"ref": []any{"master", "main"},
},
"baz": map[string]any{
"project_path": []any{"baz"},
},
},
"jwt_auth_host": "http://localhost:8200",
"jwt_auth_path": "foo/jwt",
}
}
49 changes: 6 additions & 43 deletions internal/jwtauth/path_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,35 @@ import (
"context"
"reflect"
"testing"
"time"

"github.com/go-test/deep"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/logical"
)

func TestConfig_Write(t *testing.T) {
t.Parallel()
backend, storage := createTestBackend(t)

data := testConfig()
configData := testConfig()
req := &logical.Request{
Operation: logical.UpdateOperation,
Path: configPath,
Storage: storage,
Data: data,
Data: configData,
}

resp, err := backend.HandleRequest(context.Background(), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}

conf, err := backend.(*multiroleJWTAuthBackend).config(context.Background(), storage)
conf, err := backend.config(context.Background(), storage)
if err != nil {
t.Fatal(err)
}

expected := &multiroleJWTConfig{
Roles: data["roles"].(map[string]any),
Roles: configData["roles"].(map[string]any),
JWTAuthHost: "http://localhost:8200",
JWTAuthPath: "foo/jwt",
}
Expand Down Expand Up @@ -96,7 +93,7 @@ func TestConfig_Delete(t *testing.T) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}

conf, err := backend.(*multiroleJWTAuthBackend).config(context.Background(), storage)
conf, err := backend.config(context.Background(), storage)
if err != nil {
t.Fatal(err)
}
Expand All @@ -116,45 +113,11 @@ func TestConfig_Delete(t *testing.T) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}

conf, err = backend.(*multiroleJWTAuthBackend).config(context.Background(), storage)
conf, err = backend.config(context.Background(), storage)
if err != nil {
t.Fatal(err)
}
if conf != nil {
t.Fatal("expected config to not exist after delete")
}
}

func createTestBackend(t *testing.T) (logical.Backend, logical.Storage) {
config := &logical.BackendConfig{
Logger: logging.NewVaultLogger(log.Trace),

System: &logical.StaticSystemView{
DefaultLeaseTTLVal: time.Hour * 12,
MaxLeaseTTLVal: time.Hour * 24,
},
StorageView: &logical.InmemStorage{},
}
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatalf("unable to create backend: %v", err)
}

return b, config.StorageView
}

func testConfig() map[string]any {
return map[string]any{
"roles": map[string]any{
"foo": map[string]any{
"project_path": []any{"foo", "bar"},
},
"bar": map[string]any{
"namespace_path": []any{"c"},
"ref": []any{"master", "main"},
},
},
"jwt_auth_host": "http://localhost:8200",
"jwt_auth_path": "foo/jwt",
}
}
9 changes: 4 additions & 5 deletions internal/jwtauth/path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"time"

jwt "github.com/golang-jwt/jwt/v5"
"github.com/hashicorp/vault-client-go"
"github.com/hashicorp/vault-client-go/schema"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
Expand Down Expand Up @@ -86,24 +85,24 @@ func (b *multiroleJWTAuthBackend) pathLogin(
func (b *multiroleJWTAuthBackend) policies(
ctx context.Context, config *multiroleJWTConfig, roles []string, token string,
) ([]string, error) {
client, err := b.getClient(config)
client, err := b.getPolicyClient(config)
if err != nil {
return nil, err
}

policies := map[string]struct{}{}
for _, role := range roles {
response, err := client.Auth.JwtLogin(ctx, schema.JwtLoginRequest{
rolePolicies, err := client.policies(ctx, schema.JwtLoginRequest{
Jwt: token,
Role: role,
}, vault.WithMountPath(config.JWTAuthPath))
})
if err != nil {
continue
// TODO: return error if non-403
// return nil, err
}

for _, p := range response.Auth.Policies {
for _, p := range rolePolicies {
policies[p] = struct{}{}
}
}
Expand Down
112 changes: 112 additions & 0 deletions internal/jwtauth/path_login_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package jwtauth

import (
"context"
"crypto/rand"
"crypto/rsa"
"testing"

jwt "github.com/golang-jwt/jwt/v5"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/hashicorp/vault-client-go/schema"
"github.com/hashicorp/vault/sdk/logical"
)

func TestLogin_Write(t *testing.T) {
t.Parallel()
backend, storage := createTestBackend(t)

configData := map[string]any{
"roles": map[string]any{
"foo": map[string]any{
"project_path": []any{"foo"},
},
"foobar": map[string]any{
"namespace_path": []any{"ns"},
"project_path": []any{"foo", "bar"},
},
"baz": map[string]any{
"project_path": []any{"baz"},
},
},
"jwt_auth_host": "http://localhost:8200",
"jwt_auth_path": "jwt",
}
req := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config",
Storage: storage,
Data: configData,
}

resp, err := backend.HandleRequest(context.Background(), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}

var policyClient fakePolicyFetcher = func(_ context.Context, request schema.JwtLoginRequest) ([]string, error) {
return []string{request.Role + "-policy"}, nil
}
backend.policyClient = policyClient

privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("unable to create private key: %s", err.Error())
}

for i, tt := range []struct {
claims jwt.MapClaims
policies []string
}{
{
claims: jwt.MapClaims{
"project_path": "foo",
},
policies: []string{"foo-policy"},
},
{
claims: jwt.MapClaims{
"project_path": "foo",
"namespace_path": "ns",
},
policies: []string{"foo-policy", "foobar-policy"},
},
{
claims: jwt.MapClaims{
"project_path": "baz",
},
policies: []string{"baz-policy"},
},
} {
token, err := jwt.NewWithClaims(jwt.SigningMethodPS512, tt.claims).SignedString(privateKey)
if err != nil {
t.Fatalf("unable to sign jwt: %s", err.Error())
}

req = &logical.Request{
Operation: logical.UpdateOperation,
Path: "login",
Storage: storage,
Data: map[string]any{
"jwt": token,
},
}

resp, err = backend.HandleRequest(context.Background(), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}

opt := cmpopts.SortSlices(func(a, b string) bool { return a < b })
if diff := cmp.Diff(resp.Auth.Policies, tt.policies, opt); diff != "" {
t.Fatalf("Test case %v failed with diff:\n%s", i, diff)
}
}
}

type fakePolicyFetcher func(context.Context, schema.JwtLoginRequest) ([]string, error)

func (c fakePolicyFetcher) policies(ctx context.Context, request schema.JwtLoginRequest) ([]string, error) {
return c(ctx, request)
}

0 comments on commit c96bb40

Please sign in to comment.