Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PKCE support to upstream (167b6c4) #116

Open
wants to merge 6 commits into
base: ow/4bc9904b3cd0726d3f9c3cbaeade972cf167b6c4
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ Pipfile
!/data
/data/*
/open_webui/data/*
/open-webui-pipelines/pipelines/*/valves.json
.webui_secret_key
21 changes: 21 additions & 0 deletions backend/open_webui/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,23 @@ def __getattr__(self, key):
[role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")],
)

OAUTH_ACR_CLAIM = PersistentConfig(
"OAUTH_ACR_CLAIM",
"oauth.oidc.acr_claim",
os.environ.get("OAUTH_ACR_CLAIM", ""),
)
OAUTH_NONCE_CLAIM = PersistentConfig(
"OAUTH_NONCE_CLAIM",
"oauth.oidc.nonce_claim",
os.environ.get("OAUTH_NONCE_CLAIM", ""),
)

OAUTH_USE_PKCE = PersistentConfig(
"OAUTH_USE_PKCE",
"oauth.oidc.use_pkce",
os.environ.get("OAUTH_USE_PKCE", ""),
)


def load_oauth_providers():
OAUTH_PROVIDERS.clear()
Expand Down Expand Up @@ -460,6 +477,10 @@ def load_oauth_providers():
"redirect_uri": OPENID_REDIRECT_URI.value,
}

# TODO: does this work out of the box for google and microsoft, too?
if OAUTH_USE_PKCE.value:
OAUTH_PROVIDERS["oidc"]["pkce"] = OAUTH_USE_PKCE.value


load_oauth_providers()

Expand Down
29 changes: 25 additions & 4 deletions backend/open_webui/utils/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import aiohttp
from authlib.integrations.starlette_client import OAuth
from authlib.common.security import generate_token
from authlib.oidc.core import UserInfo
from fastapi import (
HTTPException,
Expand All @@ -26,6 +27,8 @@
OAUTH_USERNAME_CLAIM,
OAUTH_ALLOWED_ROLES,
OAUTH_ADMIN_ROLES,
OAUTH_ACR_CLAIM,
OAUTH_NONCE_CLAIM,
WEBHOOK_URL,
JWT_EXPIRES_IN,
AppConfig,
Expand All @@ -49,6 +52,8 @@
auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
auth_manager_config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
auth_manager_config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
auth_manager_config.OAUTH_ACR_CLAIM = OAUTH_ACR_CLAIM
auth_manager_config.OAUTH_NONCE_CLAIM = OAUTH_NONCE_CLAIM
auth_manager_config.WEBHOOK_URL = WEBHOOK_URL
auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN

Expand All @@ -57,14 +62,17 @@ class OAuthManager:
def __init__(self):
self.oauth = OAuth()
for provider_name, provider_config in OAUTH_PROVIDERS.items():
client_kwargs = {"scope": provider_config["scope"]}
client_kwargs.update(
{"code_challenge_method": "S256"} if provider_config["pkce"] else {}
)

self.oauth.register(
name=provider_name,
client_id=provider_config["client_id"],
client_secret=provider_config["client_secret"],
server_metadata_url=provider_config["server_metadata_url"],
client_kwargs={
"scope": provider_config["scope"],
},
client_kwargs=client_kwargs,
redirect_uri=provider_config["redirect_uri"],
)

Expand Down Expand Up @@ -127,7 +135,20 @@ async def handle_login(self, provider, request):
client = self.get_client(provider)
if client is None:
raise HTTPException(404)
return await client.authorize_redirect(request, redirect_uri)

redirect_kwargs = {}
if auth_manager_config.OAUTH_ACR_CLAIM:
redirect_kwargs.update({"acr_values": auth_manager_config.OAUTH_ACR_CLAIM})
if auth_manager_config.OAUTH_NONCE_CLAIM:
redirect_kwargs.update(
{"nonce": generate_token(int(auth_manager_config.OAUTH_NONCE_CLAIM))}
)

return await client.authorize_redirect(
request,
redirect_uri=redirect_uri,
**redirect_kwargs,
)

async def handle_callback(self, provider, request, response):
if provider not in OAUTH_PROVIDERS:
Expand Down
125 changes: 125 additions & 0 deletions backend/open_webui/utils/test_oauth_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import pytest
import logging
from fastapi import Request
from fastapi.datastructures import URL
from starlette.datastructures import Headers, QueryParams
from starlette.responses import RedirectResponse
from unittest.mock import patch, AsyncMock, MagicMock
from urllib.parse import urlparse, parse_qs
import json

from authlib.integrations.starlette_client import OAuth
from open_webui.utils.oauth import OAuthManager, OAUTH_PROVIDERS

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@pytest.fixture
def mock_request():
async def receive():
return {"type": "http.request"}

headers = Headers({"host": "testserver"})
query_params = QueryParams({})
url = URL("http://testserver/test/path")
request = AsyncMock(
spec=Request,
receive=receive,
headers=headers,
query_params=query_params,
url=url,
)
request.url_for = AsyncMock(return_value=url)
return request


@pytest.mark.asyncio
async def test_handle_login(mock_request):
# Set up a test OAuth provider configuration
test_provider_config = {
"oidc": {
"client_id": "test_client_id",
"client_secret": "test_client_secret",
"server_metadata_url": "https://uaa.fr.cloud.gov/.well-known/openid-configuration",
"scope": "openid profile email",
"redirect_uri": "http://localhost/oauth/oidc/callback",
}
}

# Mock the server metadata response
mock_metadata = {
"issuer": "https://uaa.fr.cloud.gov",
"authorization_endpoint": "https://uaa.fr.cloud.gov/oauth/authorize",
"token_endpoint": "https://uaa.fr.cloud.gov/oauth/token",
"userinfo_endpoint": "https://uaa.fr.cloud.gov/userinfo",
"jwks_uri": "https://uaa.fr.cloud.gov/token_keys",
"scopes_supported": ["openid", "profile", "email"],
"response_types_supported": ["code", "token", "id_token"],
"grant_types_supported": ["authorization_code", "implicit"],
}

with patch.dict(OAUTH_PROVIDERS, test_provider_config):
oauth_manager = OAuthManager()

# Patch the OAuth.create_client method to return a client that we can inspect
with (
patch.object(OAuth, "create_client") as mock_create_client,
patch("aiohttp.ClientSession.get") as mock_get,
):

# Mock the HTTP request to the metadata URL
mock_response = AsyncMock()
mock_response.status = 200
mock_response.json.return_value = mock_metadata
mock_get.return_value.__aenter__.return_value = mock_response

mock_client = AsyncMock()
mock_create_client.return_value = mock_client

# Set up the mock client's authorize_redirect method
async def mock_authorize_redirect(request, redirect_uri):
auth_url = mock_metadata["authorization_endpoint"]
params = {
"response_type": "code",
"client_id": "test_client_id",
"redirect_uri": redirect_uri,
"scope": "openid profile email",
"state": "some_state_value",
}
query_string = "&".join(f"{k}={v}" for k, v in params.items())
full_url = f"{auth_url}?{query_string}"
return RedirectResponse(url=full_url)

mock_client.authorize_redirect = mock_authorize_redirect

# Call the handle_login method
redirect_response = await oauth_manager.handle_login("oidc", mock_request)

# Log the redirect response
logger.info("\n\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n")
logger.info(f"Redirect location: {redirect_response.headers['location']}")
logger.info("\n\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n")

# Assert that we got a RedirectResponse
assert isinstance(redirect_response, RedirectResponse)
assert redirect_response.status_code == 307 # Temporary Redirect

# Parse the URL from the response
parsed_url = urlparse(redirect_response.headers["location"])
query_params = parse_qs(parsed_url.query)

# Assert on the components of the URL
assert parsed_url.scheme == "https"
assert parsed_url.netloc == "uaa.fr.cloud.gov"
assert parsed_url.path == "/oauth/authorize"

# Assert on the query parameters
assert query_params["response_type"] == ["code"]
assert query_params["client_id"] == ["test_client_id"]
assert query_params["redirect_uri"] == [
"http://localhost/oauth/oidc/callback"
]
assert query_params["scope"] == ["openid profile email"]
assert "state" in query_params
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,4 @@
"node": ">=18.13.0 <=22.x.x",
"npm": ">=6.0.0"
}
}
}