Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Lagoon API DB to determine project group membership #441

Merged
merged 8 commits into from
May 10, 2024
61 changes: 61 additions & 0 deletions internal/cache/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Package cache implements a generic, thread-safe, in-memory cache.
package cache

import (
"sync"
"time"
)

const (
defaultTTL = time.Minute
)

// Cache is a generic, thread-safe, in-memory cache that stores a value with a
// TTL, after which the cache expires.
type Cache[T any] struct {
data T
expiry time.Time
ttl time.Duration
mu sync.Mutex
}

// Option is a functional option argument to NewCache().
type Option[T any] func(*Cache[T])

// WithTTL sets the the Cache time-to-live to ttl.
func WithTTL[T any](ttl time.Duration) Option[T] {
return func(c *Cache[T]) {
c.ttl = ttl
}
}

// NewCache instantiates a Cache for type T with a default TTL of 1 minute.
func NewCache[T any](options ...Option[T]) *Cache[T] {
c := Cache[T]{
ttl: defaultTTL,
}
for _, option := range options {
option(&c)
}
return &c
}

// Set updates the value in the cache and sets the expiry to now+TTL.
func (c *Cache[T]) Set(value T) {
c.mu.Lock()
defer c.mu.Unlock()
c.data = value
c.expiry = time.Now().Add(c.ttl)
}

// Get retrieves the value from the cache. If cache has expired, the second
// return value will be false.
func (c *Cache[T]) Get() (T, bool) {
c.mu.Lock()
defer c.mu.Unlock()
if time.Now().After(c.expiry) {
var zero T
return zero, false
}
return c.data, true
}
69 changes: 69 additions & 0 deletions internal/cache/cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package cache_test

import (
"testing"
"time"

"github.com/alecthomas/assert/v2"
"github.com/uselagoon/ssh-portal/internal/cache"
)

func TestIntCache(t *testing.T) {
var testCases = map[string]struct {
input int
expect int
expired bool
}{
"not expired": {input: 11, expect: 11},
"expired": {input: 11, expired: true},
}
for name, tc := range testCases {
t.Run(name, func(tt *testing.T) {
c := cache.NewCache[int](cache.WithTTL[int](time.Second))
c.Set(tc.input)
if tc.expired {
time.Sleep(2 * time.Second)
_, ok := c.Get()
assert.False(tt, ok, name)
} else {
value, ok := c.Get()
assert.True(tt, ok, name)
assert.Equal(tt, tc.expect, value, name)
}
})
}
}

func TestMapCache(t *testing.T) {
var testCases = map[string]struct {
input map[string]string
expect map[string]string
expired bool
}{
"expired": {
input: map[string]string{"foo": "bar"},
expired: true,
},
"not expired": {
input: map[string]string{"foo": "bar"},
expect: map[string]string{"foo": "bar"},
},
}
for name, tc := range testCases {
t.Run(name, func(tt *testing.T) {
c := cache.NewCache[map[string]string](
cache.WithTTL[map[string]string](time.Second),
)
c.Set(tc.input)
if tc.expired {
time.Sleep(2 * time.Second)
_, ok := c.Get()
assert.False(tt, ok, name)
} else {
value, ok := c.Get()
assert.True(tt, ok, name)
assert.Equal(tt, tc.expect, value, name)
}
})
}
}
10 changes: 9 additions & 1 deletion internal/keycloak/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"github.com/MicahParks/keyfunc/v2"
"github.com/uselagoon/ssh-portal/internal/cache"
oidcClient "github.com/zitadel/oidc/v3/pkg/client"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/time/rate"
Expand All @@ -21,23 +22,28 @@ const pkgName = "github.com/uselagoon/ssh-portal/internal/keycloak"

// Client is a keycloak client.
type Client struct {
baseURL *url.URL
clientID string
clientSecret string
jwks *keyfunc.JWKS
log *slog.Logger
oidcConfig *oidc.DiscoveryConfiguration
limiter *rate.Limiter

// groupNameGroupIDMap cache
groupCache *cache.Cache[map[string]string]
}

// NewClient creates a new keycloak client for the lagoon realm.
func NewClient(ctx context.Context, log *slog.Logger, keycloakURL, clientID,
clientSecret string, rateLimit int) (*Client, error) {
// discover OIDC config
issuerURL, err := url.Parse(keycloakURL)
baseURL, err := url.Parse(keycloakURL)
if err != nil {
return nil, fmt.Errorf("couldn't parse keycloak base URL %s: %v",
keycloakURL, err)
}
issuerURL := *baseURL
issuerURL.Path = path.Join(issuerURL.Path, "auth/realms/lagoon")
oidcConfig, err := oidcClient.Discover(ctx, issuerURL.String(),
&http.Client{Timeout: 8 * time.Second})
Expand All @@ -50,11 +56,13 @@ func NewClient(ctx context.Context, log *slog.Logger, keycloakURL, clientID,
return nil, fmt.Errorf("couldn't get keycloak lagoon realm JWKS: %v", err)
}
return &Client{
baseURL: baseURL,
clientID: clientID,
clientSecret: clientSecret,
jwks: jwks,
log: log,
oidcConfig: oidcConfig,
limiter: rate.NewLimiter(rate.Limit(rateLimit), rateLimit),
groupCache: cache.NewCache[map[string]string](),
}, nil
}
82 changes: 82 additions & 0 deletions internal/keycloak/groups.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package keycloak

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"path"

"golang.org/x/oauth2/clientcredentials"
)

// Group represents a Keycloak Group. It holds the fields required when getting
// a list of groups from keycloak.
type Group struct {
ID string `json:"id"`
Name string `json:"name"`
}

func (c *Client) httpClient(ctx context.Context) *http.Client {
cc := clientcredentials.Config{
ClientID: c.clientID,
ClientSecret: c.clientSecret,
TokenURL: c.oidcConfig.TokenEndpoint,
}
return cc.Client(ctx)
}

// rawGroups returns the raw JSON group representation from the Keycloak API.
func (c *Client) rawGroups(ctx context.Context) ([]byte, error) {
groupsURL := *c.baseURL
groupsURL.Path = path.Join(c.baseURL.Path,
"/auth/admin/realms/lagoon/groups")
req, err := http.NewRequestWithContext(ctx, "GET", groupsURL.String(), nil)
if err != nil {
return nil, fmt.Errorf("couldn't construct groups request: %v", err)
}
q := req.URL.Query()
q.Add("briefRepresentation", "true")
req.URL.RawQuery = q.Encode()
res, err := c.httpClient(ctx).Do(req)
if err != nil {
return nil, fmt.Errorf("couldn't get groups: %v", err)
}
defer res.Body.Close()
if res.StatusCode > 299 {
body, _ := io.ReadAll(res.Body)
return nil, fmt.Errorf("bad groups response: %d\n%s", res.StatusCode, body)
}
return io.ReadAll(res.Body)
}

// GroupNameGroupIDMap returns a map of Keycloak Group names to Group IDs.
func (c *Client) GroupNameGroupIDMap(
ctx context.Context,
) (map[string]string, error) {
// rate limit keycloak API access
if err := c.limiter.Wait(ctx); err != nil {
return nil, fmt.Errorf("couldn't wait for limiter: %v", err)
}
// prefer to use cached value
if groupNameGroupIDMap, ok := c.groupCache.Get(); ok {
return groupNameGroupIDMap, nil
}
// otherwise get data from keycloak
data, err := c.rawGroups(ctx)
if err != nil {
return nil, fmt.Errorf("couldn't get groups from Keycloak API: %v", err)
}
var groups []Group
if err := json.Unmarshal(data, &groups); err != nil {
return nil, fmt.Errorf("couldn't unmarshal Keycloak groups: %v", err)
}
groupNameGroupIDMap := map[string]string{}
for _, group := range groups {
groupNameGroupIDMap[group.Name] = group.ID
}
// update cache
c.groupCache.Set(groupNameGroupIDMap)
return groupNameGroupIDMap, nil
}
36 changes: 3 additions & 33 deletions internal/keycloak/jwt.go
Original file line number Diff line number Diff line change
@@ -1,47 +1,17 @@
package keycloak

import (
"encoding/json"
"fmt"

"github.com/golang-jwt/jwt/v5"
"golang.org/x/oauth2"
)

type groupProjectIDs map[string][]int

func (gpids *groupProjectIDs) UnmarshalJSON(data []byte) error {
// unmarshal the double-encoded group-pid attributes
var gpas []string
if err := json.Unmarshal(data, &gpas); err != nil {
return err
}
// convert the slice of encoded group-pid attributes into a slice of
// group-pid maps
var gpms []map[string][]int
for _, gpa := range gpas {
var gpm map[string][]int
if err := json.Unmarshal([]byte(gpa), &gpm); err != nil {
return err
}
gpms = append(gpms, gpm)
}
// flatten the slice of group-pid maps into a single map
*gpids = groupProjectIDs{}
for _, gpm := range gpms {
for k, v := range gpm {
(*gpids)[k] = v
}
}
return nil
}

// LagoonClaims contains the token claims used by Lagoon.
type LagoonClaims struct {
RealmRoles []string `json:"realm_roles"`
UserGroups []string `json:"group_membership"`
GroupProjectIDs groupProjectIDs `json:"group_lagoon_project_ids"`
AuthorizedParty string `json:"azp"`
RealmRoles []string `json:"realm_roles"`
UserGroups []string `json:"group_membership"`
AuthorizedParty string `json:"azp"`
jwt.RegisteredClaims

clientID string `json:"-"`
Expand Down
14 changes: 2 additions & 12 deletions internal/keycloak/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,8 @@ func TestUnmarshalLagoonClaims(t *testing.T) {
"{\"credentialtest-group1\":[1]}",
"{\"ci-group\":[3,4,5,6,7,8,9,10,11,12,17,14,16,20,21,24,19,23,31]}"]}`),
expect: &keycloak.LagoonClaims{
RealmRoles: nil,
UserGroups: nil,
GroupProjectIDs: map[string][]int{
"credentialtest-group1": {1},
"ci-group": {3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 17, 14, 16, 20, 21, 24,
19, 23, 31},
},
RealmRoles: nil,
UserGroups: nil,
RegisteredClaims: jwt.RegisteredClaims{},
},
},
Expand Down Expand Up @@ -97,11 +92,6 @@ func TestUnmarshalLagoonClaims(t *testing.T) {
UserGroups: []string{
"/ci-group/ci-group-owner",
"/credentialtest-group1/credentialtest-group1-owner"},
GroupProjectIDs: map[string][]int{
"credentialtest-group1": {1},
"ci-group": {3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 17, 14, 16, 20, 21, 24,
19, 23, 31},
},
AuthorizedParty: "service-api",
RegisteredClaims: jwt.RegisteredClaims{
ID: "ba279e79-4f38-43ae-83e7-fe461aad59d1",
Expand Down
13 changes: 6 additions & 7 deletions internal/keycloak/userrolesandgroups.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,15 @@ import (
)

// UserRolesAndGroups queries Keycloak given the user UUID, and returns the
// user's realm roles, group memberships, and the project IDs associated with
// those groups.
// user's realm roles, and group memberships (by name, including subgroups).
func (c *Client) UserRolesAndGroups(ctx context.Context,
userUUID *uuid.UUID) ([]string, []string, map[string][]int, error) {
userUUID *uuid.UUID) ([]string, []string, error) {
// set up tracing
ctx, span := otel.Tracer(pkgName).Start(ctx, "UserRolesAndGroups")
defer span.End()
// rate limit keycloak API access
if err := c.limiter.Wait(ctx); err != nil {
return nil, nil, nil, fmt.Errorf("couldn't wait for limiter: %v", err)
return nil, nil, fmt.Errorf("couldn't wait for limiter: %v", err)
}
// get user token
userConfig := oauth2.Config{
Expand All @@ -41,12 +40,12 @@ func (c *Client) UserRolesAndGroups(ctx context.Context,
// https://www.keycloak.org/docs/latest/securing_apps/#_token-exchange
oauth2.SetAuthURLParam("requested_subject", userUUID.String()))
if err != nil {
return nil, nil, nil, fmt.Errorf("couldn't get user token: %v", err)
return nil, nil, fmt.Errorf("couldn't get user token: %v", err)
}
// parse and extract verified attributes
claims, err := c.parseAccessToken(userToken, userUUID.String())
if err != nil {
return nil, nil, nil, fmt.Errorf("couldn't parse user access token: %v", err)
return nil, nil, fmt.Errorf("couldn't parse user access token: %v", err)
}
return claims.RealmRoles, claims.UserGroups, claims.GroupProjectIDs, nil
return claims.RealmRoles, claims.UserGroups, nil
}
Loading
Loading