diff --git a/sessions/TODO b/sessions/TODO new file mode 100644 index 0000000..16f8b4f --- /dev/null +++ b/sessions/TODO @@ -0,0 +1,4 @@ +- Update Session data to use defined types instead of string map +- Create type assertions for Session data +- Refactor functions to use newly defined types + assertions +- Phase in new_context.go functions over context.go \ No newline at end of file diff --git a/sessions/new_context.go b/sessions/new_context.go new file mode 100644 index 0000000..528db09 --- /dev/null +++ b/sessions/new_context.go @@ -0,0 +1,92 @@ +package sessions + +import ( + "context" + + "github.com/pkg/errors" + + "github.com/theopenlane/utils/contextx" + "golang.org/x/oauth2" +) + +// NewContextWithToken returns a copy of ctx that stores the Token +func NewContextWithToken(ctx context.Context, token *oauth2.Token) context.Context { + return contextx.With(ctx, token) +} + +// NewOhAuthTokenFromContext returns the Token from the ctx +func NewOhAuthTokenFromContext(ctx context.Context) (*oauth2.Token, error) { + token, ok := contextx.From[*oauth2.Token](ctx) + if !ok { + return nil, errors.New("context missing Token") + } + + return token, nil +} + +// NewUserIDFromContext returns the user ID from the ctx +// this function assumes the session data is stored in a string map +func NewUserIDFromContext(ctx context.Context) (string, error) { + sessionDetails, ok := contextx.From[*Session[any]](ctx) + if !ok { + return "", ErrInvalidSession + } + + sessionID := sessionDetails.GetKey() + + sessionData, ok := sessionDetails.GetOk(sessionID) + if !ok { + return "", ErrInvalidSession + } + + sd, ok := sessionData.(map[string]string) + if !ok { + return "", ErrInvalidSession + } + + userID, ok := sd["userID"] + if !ok { + return "", ErrInvalidSession + } + + return userID, nil +} + +type UserID string + +// NewContextWithUserID returns a copy of ctx that stores the user ID +func NewContextWithUserID(ctx context.Context, userID UserID) context.Context { + if userID == "" { + return ctx + } + + return contextx.With(ctx, userID) +} + +// NewSessionToken returns the session token from the context +func NewSessionToken(ctx context.Context) (string, error) { + sd, err := newGetSessionDataFromContext(ctx) + if err != nil { + return "", err + } + + sd.mu.Lock() + defer sd.mu.Unlock() + + return sd.store.EncodeCookie(sd) +} + +// NewAddSessionDataToContext adds session data to the context +func (s *Session[P]) NewAddSessionDataToContext(ctx context.Context) context.Context { + return contextx.With(ctx, s) +} + +// newGetSessionDataFromContext retrieves session data from the context +func newGetSessionDataFromContext(ctx context.Context) (*Session[map[string]any], error) { + sessionData, ok := contextx.From[*Session[map[string]any]](ctx) + if !ok { + return nil, errors.New("context missing session data") + } + + return sessionData, nil +} diff --git a/sessions/newcontext_test.go b/sessions/newcontext_test.go new file mode 100644 index 0000000..c1c6c2a --- /dev/null +++ b/sessions/newcontext_test.go @@ -0,0 +1,74 @@ +package sessions_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/theopenlane/utils/contextx" + "golang.org/x/oauth2" + + "github.com/theopenlane/iam/sessions" +) + +func TestNewContextWithToken(t *testing.T) { + ctx := context.Background() + token := &oauth2.Token{AccessToken: "test_token"} + + ctx = sessions.NewContextWithToken(ctx, token) + retrievedToken, ok := contextx.From[*oauth2.Token](ctx) + + assert.True(t, ok) + assert.Equal(t, token, retrievedToken) +} + +func TestNewOhAuthTokenFromContext(t *testing.T) { + ctx := context.Background() + token := &oauth2.Token{AccessToken: "test_token"} + + ctx = sessions.NewContextWithToken(ctx, token) + retrievedToken, err := sessions.NewOhAuthTokenFromContext(ctx) + + assert.NoError(t, err) + assert.Equal(t, token, retrievedToken) +} + +func TestNewOhAuthTokenFromContext_MissingToken(t *testing.T) { + ctx := context.Background() + + _, err := sessions.NewOhAuthTokenFromContext(ctx) + + assert.Error(t, err) + assert.Equal(t, "context missing Token", err.Error()) +} + +func TestNewUserIDFromContext_MissingSession(t *testing.T) { + ctx := context.Background() + + _, err := sessions.NewUserIDFromContext(ctx) + + assert.Error(t, err) + assert.Equal(t, sessions.ErrInvalidSession, err) +} + +func TestNewContextWithUserID(t *testing.T) { + ctx := context.Background() + userID := sessions.UserID("test_user") + + ctx = sessions.NewContextWithUserID(ctx, userID) + retrievedUserID, ok := contextx.From[sessions.UserID](ctx) + + assert.True(t, ok) + assert.Equal(t, userID, retrievedUserID) +} + +func TestNewContextWithUserID_EmptyUserID(t *testing.T) { + ctx := context.Background() + userID := sessions.UserID("") + + ctx = sessions.NewContextWithUserID(ctx, userID) + retrievedUserID, ok := contextx.From[sessions.UserID](ctx) + + assert.False(t, ok) + assert.Equal(t, sessions.UserID(""), retrievedUserID) +}