Skip to content

Commit

Permalink
fix(auth): replace hardcoded OAuth endpoints with OIDC Discovery (#57)
Browse files Browse the repository at this point in the history
* docs: update OIDC redirect URL path

* fix(auth): replace hardcoded OAuth endpoints with OIDC Discovery

* fix(auth): remove redundant userinfo HTTP request

Previously, we required OIDC providers to support userInfo endpoints,
which created unnecessary configuration burden for users (especially
Authentik). Since we only need to validate the token, we now skip
the userInfo request entirely and use the session data we already have.

* feat(web): improve auth state handling

- Add isChecking state to prevent concurrent auth checks
- Remove debounce function as it's no longer needed
- Improve error handling and auth state management

* fix(web): fix plex timeout cleanup in useServiceData

Capture timeout ref value when effect runs to prevent potential race
conditions during cleanup. This ensures we're cleaning up the correct
Plex session polling timeout.

* feat(auth): remove hardcoded OIDC endpoint fallbacks

Remove fallback endpoint logic to strictly follow the OpenID Connect spec.
Improve error messaging by providing a link to the official discovery
specification.

* fix: tests
  • Loading branch information
s0up4200 authored Nov 18, 2024
1 parent 9e1d924 commit 051ccaa
Show file tree
Hide file tree
Showing 10 changed files with 386 additions and 321 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ Required OIDC environment variables:
OIDC_ISSUER=https://your-provider.com
OIDC_CLIENT_ID=your-client-id
OIDC_CLIENT_SECRET=your-client-secret
OIDC_REDIRECT_URL=http://localhost:3000/auth/callback
OIDC_REDIRECT_URL=http://localhost:3000/api/auth/callback
```

## Tech Stack
Expand Down
2 changes: 1 addition & 1 deletion docs/env_vars.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,5 +99,5 @@

- `OIDC_REDIRECT_URL`
- Purpose: Callback URL for OIDC authentication
- Example: `http://localhost:3000/auth/callback`
- Example: `http://localhost:3000/api/auth/callback`
- Required if using OIDC
235 changes: 116 additions & 119 deletions internal/api/handlers/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ package handlers
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
Expand All @@ -26,41 +27,122 @@ type AuthHandler struct {
cache cache.Store
oauth2Config *oauth2.Config
httpClient *http.Client
userinfoURL string
}

func NewAuthHandler(config *types.AuthConfig, store cache.Store) *AuthHandler {
// Ensure issuer URL doesn't have trailing slash
issuer := strings.TrimRight(config.Issuer, "/")
httpClient := &http.Client{Timeout: 1 * time.Second}
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()

log.Debug().
Str("issuer", config.Issuer).
Msg("initializing auth handler")

// Get provider endpoints through OIDC discovery
endpoints, userinfoURL, err := getProviderEndpoints(ctx, httpClient, config.Issuer)
if err != nil {
log.Error().Err(err).
Msg("OIDC discovery failed. Please ensure your provider supports OpenID Connect discovery as specified in https://openid.net/specs/openid-connect-discovery-1_0.html")
return nil
}

log.Debug().
Str("auth_url", endpoints.AuthURL).
Str("token_url", endpoints.TokenURL).
Msg("using discovered endpoints")

oauth2Config := &oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
RedirectURL: config.RedirectURL,
Endpoint: oauth2.Endpoint{
AuthURL: fmt.Sprintf("%s/authorize", issuer),
TokenURL: fmt.Sprintf("%s/oauth/token", issuer),
},
Scopes: []string{"openid", "profile", "email"},
Endpoint: endpoints,
Scopes: []string{"openid", "profile", "email"},
}

return &AuthHandler{
config: config,
cache: store,
oauth2Config: oauth2Config,
httpClient: &http.Client{Timeout: 10 * time.Second},
httpClient: httpClient,
userinfoURL: userinfoURL,
}
}

type providerConfig struct {
Issuer string `json:"issuer"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserinfoEndpoint string `json:"userinfo_endpoint"`
}

// getProviderEndpoints fetches provider configuration and returns oauth2.Endpoint
// Examples:
// Simple issuer (Google):
//
// Input: https://accounts.google.com
// Result: https://accounts.google.com/.well-known/openid-configuration
//
// Path-based (e.g. Keycloak realm):
//
// Input: https://auth.example.com/realms/myrealm
// Result: https://auth.example.com/realms/myrealm/.well-known/openid-configuration
func getProviderEndpoints(ctx context.Context, client *http.Client, issuer string) (oauth2.Endpoint, string, error) {
issuer = strings.TrimRight(issuer, "/")

// Construct well-known URL according to spec
var wellKnown string
if strings.Contains(issuer, "/.well-known/openid-configuration") {
wellKnown = issuer
} else {
wellKnown = issuer + "/.well-known/openid-configuration"
}

req, err := http.NewRequestWithContext(ctx, "GET", wellKnown, nil)
if err != nil {
return oauth2.Endpoint{}, "", fmt.Errorf("creating discovery request: %w", err)
}

resp, err := client.Do(req)
if err != nil {
return oauth2.Endpoint{}, "", fmt.Errorf("fetching discovery document: %w", err)
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return oauth2.Endpoint{}, "", fmt.Errorf("reading discovery document: %w", err)
}

log.Debug().
Str("issuer", issuer).
Str("well_known_url", wellKnown).
Msg("OIDC discovery successful")

var discovery struct {
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
UserinfoURL string `json:"userinfo_endpoint"`
}

if err := json.Unmarshal(body, &discovery); err != nil {
return oauth2.Endpoint{}, "", fmt.Errorf("parsing discovery document: %w", err)
}

return oauth2.Endpoint{
AuthURL: discovery.AuthURL,
TokenURL: discovery.TokenURL,
}, discovery.UserinfoURL, nil
}

// generateSecureRandomString generates a cryptographically secure random string
func generateSecureRandomString(length int) (string, error) {
bytes := make([]byte, length)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(bytes)[:length], nil
return hex.EncodeToString(bytes)[:length], nil
}

// Login initiates the OIDC authentication flow
func (h *AuthHandler) Login(c *gin.Context) {
// Create context with timeout for login flow
ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second)
Expand Down Expand Up @@ -130,7 +212,6 @@ func (h *AuthHandler) Login(c *gin.Context) {
c.Redirect(http.StatusTemporaryRedirect, authURL)
}

// Callback handles the OIDC provider callback
func (h *AuthHandler) Callback(c *gin.Context) {
// Create context with timeout for callback handling
ctx, cancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
Expand Down Expand Up @@ -169,8 +250,10 @@ func (h *AuthHandler) Callback(c *gin.Context) {
return
}

if err := h.cache.Delete(ctx, stateKey); err != nil && err != cache.ErrKeyNotFound {
log.Error().Err(err).Msg("failed to delete state from cache")
if err := h.cache.Delete(ctx, stateKey); err != nil {
if err != cache.ErrKeyNotFound {
log.Error().Err(err).Msg("failed to delete state from cache")
}
}

// Exchange code for token using context
Expand Down Expand Up @@ -233,7 +316,6 @@ func (h *AuthHandler) Callback(c *gin.Context) {
))
}

// Logout handles user logout
func (h *AuthHandler) Logout(c *gin.Context) {
// Create context with timeout for logout
ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second)
Expand Down Expand Up @@ -285,19 +367,18 @@ func (h *AuthHandler) Logout(c *gin.Context) {
c.Redirect(http.StatusTemporaryRedirect, logoutURL)
}

// VerifyToken verifies a JWT token
func (h *AuthHandler) VerifyToken(c *gin.Context) {
// Create context with timeout for token verification
ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second)
defer cancel()

sessionID, err := c.Cookie("session")
if err != nil {
log.Trace().Msg("no session cookie found")
c.JSON(http.StatusUnauthorized, gin.H{"error": "No session found"})
return
}

// Create context with timeout for token verification
ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second)
defer cancel()

sessionKey := fmt.Sprintf("oidc:session:%s", sessionID)
var sessionData types.SessionData
if err := h.cache.Get(ctx, sessionKey, &sessionData); err != nil {
Expand All @@ -320,12 +401,7 @@ func (h *AuthHandler) VerifyToken(c *gin.Context) {
})
}

// RefreshToken handles token refresh
func (h *AuthHandler) RefreshToken(c *gin.Context) {
// Create context with timeout for token refresh
ctx, cancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
defer cancel()

sessionID, err := c.Cookie("session")
if err != nil {
log.Error().Err(err).Msg("no session cookie found")
Expand All @@ -335,43 +411,35 @@ func (h *AuthHandler) RefreshToken(c *gin.Context) {

sessionKey := fmt.Sprintf("oidc:session:%s", sessionID)
var sessionData types.SessionData
if err := h.cache.Get(ctx, sessionKey, &sessionData); err != nil {
if ctx.Err() != nil {
log.Error().Err(ctx.Err()).Msg("Context canceled while getting session")
c.JSON(http.StatusGatewayTimeout, gin.H{"error": "Operation timed out"})
return
}
if err := h.cache.Get(c.Request.Context(), sessionKey, &sessionData); err != nil {
if err == cache.ErrKeyNotFound {
log.Debug().Msg("session not found or expired")
} else {
log.Error().Err(err).Msg("failed to get session from cache")
}
c.JSON(http.StatusUnauthorized, gin.H{"error": "Session not found"})
c.JSON(http.StatusUnauthorized, gin.H{"error": "Session expired"})
return
}

token := &oauth2.Token{
AccessToken: sessionData.AccessToken,
TokenType: "Bearer",
RefreshToken: sessionData.RefreshToken,
Expiry: sessionData.ExpiresAt,
}

// Create token source with context
tokenSource := h.oauth2Config.TokenSource(ctx, token)
tokenSource := h.oauth2Config.TokenSource(c.Request.Context(), token)

// Refresh the token
newToken, err := tokenSource.Token()
if err != nil {
if ctx.Err() != nil {
log.Error().Err(ctx.Err()).Msg("Context canceled during token refresh")
c.JSON(http.StatusGatewayTimeout, gin.H{"error": "Operation timed out"})
return
}
log.Error().Err(err).Msg("token refresh failed")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to refresh token"})
return
}

// Update session with new token data
// Update session data with new token
sessionData.AccessToken = newToken.AccessToken
sessionData.RefreshToken = newToken.RefreshToken
sessionData.ExpiresAt = newToken.Expiry
Expand All @@ -380,29 +448,12 @@ func (h *AuthHandler) RefreshToken(c *gin.Context) {
}

// Store updated session
if err := h.cache.Set(ctx, sessionKey, sessionData, time.Until(newToken.Expiry)); err != nil {
if ctx.Err() != nil {
log.Error().Err(ctx.Err()).Msg("Context canceled while updating session")
c.JSON(http.StatusGatewayTimeout, gin.H{"error": "Operation timed out"})
return
}
if err := h.cache.Set(c.Request.Context(), sessionKey, sessionData, time.Until(newToken.Expiry)); err != nil {
log.Error().Err(err).Msg("failed to update session in cache")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update session"})
return
}

var isSecure = c.GetHeader("X-Forwarded-Proto") == "https"

c.SetCookie(
"session",
newToken.AccessToken,
int(time.Until(newToken.Expiry).Seconds()),
"/",
"",
isSecure,
true,
)

c.JSON(http.StatusOK, gin.H{
"access_token": newToken.AccessToken,
"token_type": newToken.TokenType,
Expand All @@ -411,77 +462,23 @@ func (h *AuthHandler) RefreshToken(c *gin.Context) {
})
}

// UserInfo returns the current user's information
func (h *AuthHandler) UserInfo(c *gin.Context) {
// Create context with timeout for userinfo request
ctx, cancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
defer cancel()

sessionID, err := c.Cookie("session")
if err != nil {
log.Error().Err(err).Msg("no session cookie found")
c.JSON(http.StatusUnauthorized, gin.H{"error": "No session found"})
return
}

sessionKey := fmt.Sprintf("oidc:session:%s", sessionID)
var sessionData types.SessionData
if err := h.cache.Get(ctx, sessionKey, &sessionData); err != nil {
if ctx.Err() != nil {
log.Error().Err(ctx.Err()).Msg("Context canceled while getting session")
c.JSON(http.StatusGatewayTimeout, gin.H{"error": "Operation timed out"})
return
}
if err == cache.ErrKeyNotFound {
log.Debug().Msg("session not found or expired")
} else {
log.Error().Err(err).Msg("failed to get session from cache")
}
c.JSON(http.StatusUnauthorized, gin.H{"error": "Session not found"})
if err := h.cache.Get(c.Request.Context(), sessionKey, &sessionData); err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid session"})
return
}

userinfoURL := fmt.Sprintf("%s/userinfo", strings.TrimRight(h.config.Issuer, "/"))
req, err := http.NewRequestWithContext(ctx, "GET", userinfoURL, nil)
if err != nil {
log.Error().Err(err).Msg("failed to create userinfo request")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
return
}

req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", sessionData.AccessToken))

resp, err := h.httpClient.Do(req)
if err != nil {
if ctx.Err() != nil {
log.Error().Err(ctx.Err()).Msg("Context canceled during userinfo request")
c.JSON(http.StatusGatewayTimeout, gin.H{"error": "Operation timed out"})
return
}
log.Error().Err(err).Msg("userinfo request failed")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user info"})
return
}

if resp == nil {
log.Error().Msg("received nil response from userinfo endpoint")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user info"})
return
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
log.Error().Int("status", resp.StatusCode).Msg("userinfo request returned non-200 status")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user info"})
return
}

var userInfo map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
log.Error().Err(err).Msg("failed to decode userinfo response")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to process user info"})
return
}

c.JSON(http.StatusOK, userInfo)
// Just return the basic session info we already have
c.JSON(http.StatusOK, gin.H{
"user_id": sessionData.UserID,
"auth_type": sessionData.AuthType,
})
}
Loading

0 comments on commit 051ccaa

Please sign in to comment.