Skip to content

Commit

Permalink
feat: purge OAuth2 tokens when DB changes (apache#31164)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored Nov 26, 2024
1 parent f077323 commit 68499a1
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 4 deletions.
32 changes: 32 additions & 0 deletions superset/commands/database/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ def run(self) -> Model:
# since they're name based
original_database_name = self._model.database_name

# Depending on the changes to the OAuth2 configuration we may need to purge
# existing personal tokens.
self._handle_oauth2()

database = DatabaseDAO.update(self._model, self._properties)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
ssh_tunnel = self._handle_ssh_tunnel(database)
Expand All @@ -88,6 +92,34 @@ def run(self) -> Model:

return database

def _handle_oauth2(self) -> None:
"""
Handle changes in OAuth2.
"""
if not self._model:
return

current_config = self._model.get_oauth2_config()
if not current_config:
return

new_config = self._properties["encrypted_extra"].get("oauth2_client_info", {})

# Keys that require purging personal tokens because they probably are no longer
# valid. For example, if the scope has changed the existing tokens are still
# associated with the old scope. Similarly, if the endpoints changed the tokens
# are probably no longer valid.
keys = {
"id",
"scope",
"authorization_request_uri",
"token_request_uri",
}
for key in keys:
if current_config.get(key) != new_config.get(key):
self._model.purge_oauth2_tokens()
break

def _handle_ssh_tunnel(self, database: Database) -> SSHTunnel | None:
"""
Delete, create, or update an SSH tunnel.
Expand Down
12 changes: 9 additions & 3 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1543,9 +1543,12 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument
# one here.
TEST_DATABASE_CONNECTION_TIMEOUT = timedelta(seconds=30)

# Details needed for databases that allows user to authenticate using personal
# OAuth2 tokens. See https://github.com/apache/superset/issues/20300 for more
# information. The scope and URIs are optional.
# Details needed for databases that allows user to authenticate using personal OAuth2
# tokens. See https://github.com/apache/superset/issues/20300 for more information. The
# scope and URIs are usually optional.
# NOTE that if you change the id, scope, or URIs in this file, you probably need to purge
# the existing tokens from the database. This needs to be done by running a query to
# delete the existing tokens.
DATABASE_OAUTH2_CLIENTS: dict[str, dict[str, Any]] = {
# "Google Sheets": {
# "id": "XXX.apps.googleusercontent.com",
Expand All @@ -1561,14 +1564,17 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument
# "token_request_uri": "https://oauth2.googleapis.com/token",
# },
}

# OAuth2 state is encoded in a JWT using the alogorithm below.
DATABASE_OAUTH2_JWT_ALGORITHM = "HS256"

# By default the redirect URI points to /api/v1/database/oauth2/ and doesn't have to be
# specified. If you're running multiple Superset instances you might want to have a
# proxy handling the redirects, since redirect URIs need to be registered in the OAuth2
# applications. In that case, the proxy can forward the request to the correct instance
# by looking at the `default_redirect_uri` attribute in the OAuth2 state object.
# DATABASE_OAUTH2_REDIRECT_URI = "http://localhost:8088/api/v1/database/oauth2/"

# Timeout when fetching access and refresh tokens.
DATABASE_OAUTH2_TIMEOUT = timedelta(seconds=30)

Expand Down
14 changes: 13 additions & 1 deletion superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import ColumnElement, expression, Select

from superset import app, db_engine_specs, is_feature_enabled
from superset import app, db, db_engine_specs, is_feature_enabled
from superset.commands.database.exceptions import DatabaseInvalidError
from superset.constants import LRU_CACHE_MAX_SIZE, PASSWORD_MASK
from superset.databases.utils import make_url_safe
Expand Down Expand Up @@ -1136,6 +1136,18 @@ def start_oauth2_dance(self) -> None:
"""
return self.db_engine_spec.start_oauth2_dance(self)

def purge_oauth2_tokens(self) -> None:
"""
Delete all OAuth2 tokens associated with this database.
This is needed when the configuration changes. For example, a new client ID and
secret probably will require new tokens. The same is valid for changes in the
scope or in the endpoints.
"""
db.session.query(DatabaseUserOAuth2Tokens).filter(
DatabaseUserOAuth2Tokens.id == self.id
).delete()


sqla.event.listen(Database, "after_insert", security_manager.database_after_insert)
sqla.event.listen(Database, "after_update", security_manager.database_after_update)
Expand Down
52 changes: 52 additions & 0 deletions tests/unit_tests/commands/databases/update_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@
from superset.exceptions import OAuth2RedirectError
from superset.extensions import security_manager

oauth2_client_info = {
"id": "client_id",
"secret": "client_secret",
"scope": "scope-a",
"redirect_uri": "redirect_uri",
"authorization_request_uri": "auth_uri",
"token_request_uri": "token_uri",
"request_content_type": "json",
}


@pytest.fixture()
def database_with_catalog(mocker: MockerFixture) -> MagicMock:
Expand Down Expand Up @@ -72,6 +82,7 @@ def database_needs_oauth2(mocker: MockerFixture) -> MagicMock:
"tab_id",
"redirect_uri",
)
database.get_oauth2_config.return_value = oauth2_client_info

return database

Expand Down Expand Up @@ -321,6 +332,47 @@ def test_update_with_oauth2(
"add_permission_view_menu",
)

database_needs_oauth2.db_engine_spec.unmask_encrypted_extra.return_value = {
"oauth2_client_info": oauth2_client_info,
}

UpdateDatabaseCommand(1, {}).run()

add_permission_view_menu.assert_not_called()
database_needs_oauth2.purge_oauth2_tokens.assert_not_called()


def test_update_with_oauth2_changed(
mocker: MockerFixture,
database_needs_oauth2: MockerFixture,
) -> None:
"""
Test that the database can be updated even if OAuth2 is needed to connect.
"""
DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO")
DatabaseDAO.find_by_id.return_value = database_needs_oauth2
DatabaseDAO.update.return_value = database_needs_oauth2

find_permission_view_menu = mocker.patch.object(
security_manager,
"find_permission_view_menu",
)
find_permission_view_menu.side_effect = [
None, # schema1 has no permissions
"[my_db].[schema2]", # second schema already exists
]
add_permission_view_menu = mocker.patch.object(
security_manager,
"add_permission_view_menu",
)

modified_oauth2_client_info = oauth2_client_info.copy()
modified_oauth2_client_info["scope"] = "scope-b"
database_needs_oauth2.db_engine_spec.unmask_encrypted_extra.return_value = {
"oauth2_client_info": modified_oauth2_client_info,
}

UpdateDatabaseCommand(1, {}).run()

add_permission_view_menu.assert_not_called()
database_needs_oauth2.purge_oauth2_tokens.assert_called()
80 changes: 80 additions & 0 deletions tests/unit_tests/models/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pytest_mock import MockerFixture
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import make_url
from sqlalchemy.orm.session import Session

from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.errors import SupersetErrorType
Expand Down Expand Up @@ -603,3 +604,82 @@ def test_engine_context_manager(mocker: MockerFixture) -> None:
source=None,
sqlalchemy_uri="trino://",
)


def test_purge_oauth2_tokens(session: Session) -> None:
"""
Test the `purge_oauth2_tokens` method.
"""
from flask_appbuilder.security.sqla.models import Role, User # noqa: F401

from superset.models.core import Database, DatabaseUserOAuth2Tokens

Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member

user = User(
first_name="Alice",
last_name="Doe",
email="[email protected]",
username="adoe",
)
session.add(user)
session.flush()

database1 = Database(database_name="my_oauth2_db", sqlalchemy_uri="sqlite://")
database2 = Database(database_name="my_other_oauth2_db", sqlalchemy_uri="sqlite://")
session.add_all([database1, database2])
session.flush()

tokens = [
DatabaseUserOAuth2Tokens(
user_id=user.id,
database_id=database1.id,
access_token="my_access_token",
access_token_expiration=datetime(2023, 1, 1),
refresh_token="my_refresh_token",
),
DatabaseUserOAuth2Tokens(
user_id=user.id,
database_id=database2.id,
access_token="my_other_access_token",
access_token_expiration=datetime(2024, 1, 1),
refresh_token="my_other_refresh_token",
),
]
session.add_all(tokens)
session.flush()

assert len(session.query(DatabaseUserOAuth2Tokens).all()) == 2

token = (
session.query(DatabaseUserOAuth2Tokens)
.filter_by(database_id=database1.id)
.one()
)
assert token.user_id == user.id
assert token.database_id == database1.id
assert token.access_token == "my_access_token"
assert token.access_token_expiration == datetime(2023, 1, 1)
assert token.refresh_token == "my_refresh_token"

database1.purge_oauth2_tokens()

# confirm token was deleted
token = (
session.query(DatabaseUserOAuth2Tokens)
.filter_by(database_id=database1.id)
.one_or_none()
)
assert token is None

# make sure other DB tokens weren't deleted
token = (
session.query(DatabaseUserOAuth2Tokens)
.filter_by(database_id=database2.id)
.one()
)
assert token is not None

# make sure database was not deleted... just in case
database = session.query(Database).filter_by(id=database1.id).one()
assert database.name == "my_oauth2_db"

0 comments on commit 68499a1

Please sign in to comment.