From fcd0a35aed20189a52bc2e1b94f217d778216f0b Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Fri, 24 Jan 2025 13:13:52 +0900 Subject: [PATCH] Fix #1636 Custom Values passed into correctly into Bot/Installation class when cloned during token rotation (#1638) --- .../oauth/installation_store/models/bot.py | 11 +++++--- .../installation_store/models/installation.py | 11 +++++--- .../oauth/token_rotation/async_rotator.py | 6 ++--- slack_sdk/oauth/token_rotation/rotator.py | 6 ++--- .../oauth/installation_store/test_models.py | 4 +++ .../token_rotation/test_token_rotator.py | 25 ++++++++++++++++++ .../token_rotation/test_token_rotator.py | 26 +++++++++++++++++++ 7 files changed, 77 insertions(+), 12 deletions(-) diff --git a/slack_sdk/oauth/installation_store/models/bot.py b/slack_sdk/oauth/installation_store/models/bot.py index 22f6dd10f..52c1dac50 100644 --- a/slack_sdk/oauth/installation_store/models/bot.py +++ b/slack_sdk/oauth/installation_store/models/bot.py @@ -87,8 +87,8 @@ def set_custom_value(self, name: str, value: Any): def get_custom_value(self, name: str) -> Optional[Any]: return self.custom_values.get(name) - def to_dict(self) -> Dict[str, Any]: - standard_values = { + def _to_standard_value_dict(self) -> Dict[str, Any]: + return { "app_id": self.app_id, "enterprise_id": self.enterprise_id, "enterprise_name": self.enterprise_name, @@ -105,6 +105,11 @@ def to_dict(self) -> Dict[str, Any]: "is_enterprise_install": self.is_enterprise_install, "installed_at": datetime.utcfromtimestamp(self.installed_at), } + + def to_dict_for_copying(self) -> Dict[str, Any]: + return {"custom_values": self.custom_values, **self._to_standard_value_dict()} + + def to_dict(self) -> Dict[str, Any]: # prioritize standard_values over custom_values # when the same keys exist in both - return {**self.custom_values, **standard_values} + return {**self.custom_values, **self._to_standard_value_dict()} diff --git a/slack_sdk/oauth/installation_store/models/installation.py b/slack_sdk/oauth/installation_store/models/installation.py index 42d4c90bd..91c6510f2 100644 --- a/slack_sdk/oauth/installation_store/models/installation.py +++ b/slack_sdk/oauth/installation_store/models/installation.py @@ -159,8 +159,8 @@ def set_custom_value(self, name: str, value: Any): def get_custom_value(self, name: str) -> Optional[Any]: return self.custom_values.get(name) - def to_dict(self) -> Dict[str, Any]: - standard_values = { + def _to_standard_value_dict(self) -> Dict[str, Any]: + return { "app_id": self.app_id, "enterprise_id": self.enterprise_id, "enterprise_name": self.enterprise_name, @@ -190,6 +190,11 @@ def to_dict(self) -> Dict[str, Any]: "token_type": self.token_type, "installed_at": datetime.utcfromtimestamp(self.installed_at), } + + def to_dict_for_copying(self) -> Dict[str, Any]: + return {"custom_values": self.custom_values, **self._to_standard_value_dict()} + + def to_dict(self) -> Dict[str, Any]: # prioritize standard_values over custom_values # when the same keys exist in both - return {**self.custom_values, **standard_values} + return {**self.custom_values, **self._to_standard_value_dict()} diff --git a/slack_sdk/oauth/token_rotation/async_rotator.py b/slack_sdk/oauth/token_rotation/async_rotator.py index 1b4047bf1..c3506f004 100644 --- a/slack_sdk/oauth/token_rotation/async_rotator.py +++ b/slack_sdk/oauth/token_rotation/async_rotator.py @@ -54,7 +54,7 @@ async def perform_token_rotation( if rotated_bot is not None: if rotated_installation is None: - rotated_installation = Installation(**installation.to_dict()) + rotated_installation = Installation(**installation.to_dict_for_copying()) rotated_installation.bot_token = rotated_bot.bot_token rotated_installation.bot_refresh_token = rotated_bot.bot_refresh_token rotated_installation.bot_token_expires_at = rotated_bot.bot_token_expires_at @@ -93,7 +93,7 @@ async def perform_bot_token_rotation( if refresh_response.get("token_type") != "bot": return None - refreshed_bot = Bot(**bot.to_dict()) + refreshed_bot = Bot(**bot.to_dict_for_copying()) refreshed_bot.bot_token = refresh_response["access_token"] refreshed_bot.bot_refresh_token = refresh_response.get("refresh_token") refreshed_bot.bot_token_expires_at = int(time()) + int(refresh_response["expires_in"]) @@ -132,7 +132,7 @@ async def perform_user_token_rotation( if refresh_response.get("token_type") != "user": return None - refreshed_installation = Installation(**installation.to_dict()) + refreshed_installation = Installation(**installation.to_dict_for_copying()) refreshed_installation.user_token = refresh_response.get("access_token") refreshed_installation.user_refresh_token = refresh_response.get("refresh_token") refreshed_installation.user_token_expires_at = int(time()) + int(refresh_response.get("expires_in")) # type: ignore[arg-type] # noqa: E501 diff --git a/slack_sdk/oauth/token_rotation/rotator.py b/slack_sdk/oauth/token_rotation/rotator.py index aa5e42916..e7dab22cc 100644 --- a/slack_sdk/oauth/token_rotation/rotator.py +++ b/slack_sdk/oauth/token_rotation/rotator.py @@ -48,7 +48,7 @@ def perform_token_rotation( if rotated_bot is not None: if rotated_installation is None: - rotated_installation = Installation(**installation.to_dict()) + rotated_installation = Installation(**installation.to_dict_for_copying()) rotated_installation.bot_token = rotated_bot.bot_token rotated_installation.bot_refresh_token = rotated_bot.bot_refresh_token rotated_installation.bot_token_expires_at = rotated_bot.bot_token_expires_at @@ -85,7 +85,7 @@ def perform_bot_token_rotation( if refresh_response.get("token_type") != "bot": return None - refreshed_bot = Bot(**bot.to_dict()) + refreshed_bot = Bot(**bot.to_dict_for_copying()) refreshed_bot.bot_token = refresh_response["access_token"] refreshed_bot.bot_refresh_token = refresh_response.get("refresh_token") refreshed_bot.bot_token_expires_at = int(time()) + int(refresh_response["expires_in"]) @@ -125,7 +125,7 @@ def perform_user_token_rotation( if refresh_response.get("token_type") != "user": return None - refreshed_installation = Installation(**installation.to_dict()) + refreshed_installation = Installation(**installation.to_dict_for_copying()) refreshed_installation.user_token = refresh_response.get("access_token") refreshed_installation.user_refresh_token = refresh_response.get("refresh_token") refreshed_installation.user_token_expires_at = int(time()) + int(refresh_response["expires_in"]) diff --git a/tests/slack_sdk/oauth/installation_store/test_models.py b/tests/slack_sdk/oauth/installation_store/test_models.py index c39e4aa03..d63964be6 100644 --- a/tests/slack_sdk/oauth/installation_store/test_models.py +++ b/tests/slack_sdk/oauth/installation_store/test_models.py @@ -20,6 +20,7 @@ def test_bot(self): ) self.assertIsNotNone(bot) self.assertIsNotNone(bot.to_dict()) + self.assertIsNotNone(bot.to_dict_for_copying()) def test_bot_custom_fields(self): bot = Bot( @@ -33,6 +34,7 @@ def test_bot_custom_fields(self): bot.set_custom_value("app_id", "A222") self.assertEqual(bot.get_custom_value("service_user_id"), "XYZ123") self.assertEqual(bot.to_dict().get("service_user_id"), "XYZ123") + self.assertEqual(bot.to_dict_for_copying().get("custom_values").get("service_user_id"), "XYZ123") def test_installation(self): installation = Installation( @@ -73,6 +75,7 @@ def test_installation_custom_fields(self): self.assertEqual(installation.get_custom_value("service_user_id"), "XYZ123") self.assertEqual(installation.to_dict().get("service_user_id"), "XYZ123") self.assertEqual(installation.to_dict().get("app_id"), "A111") + self.assertEqual(installation.to_dict_for_copying().get("custom_values").get("app_id"), "A222") bot = installation.to_bot() self.assertEqual(bot.app_id, "A111") @@ -80,3 +83,4 @@ def test_installation_custom_fields(self): self.assertEqual(bot.to_dict().get("app_id"), "A111") self.assertEqual(bot.to_dict().get("service_user_id"), "XYZ123") + self.assertEqual(bot.to_dict_for_copying().get("custom_values").get("app_id"), "A222") diff --git a/tests/slack_sdk/oauth/token_rotation/test_token_rotator.py b/tests/slack_sdk/oauth/token_rotation/test_token_rotator.py index 9926932c2..d599e3af0 100644 --- a/tests/slack_sdk/oauth/token_rotation/test_token_rotator.py +++ b/tests/slack_sdk/oauth/token_rotation/test_token_rotator.py @@ -43,6 +43,31 @@ def test_refresh(self): ) self.assertIsNone(should_not_be_refreshed) + def test_refresh_with_custom_values(self): + installation = Installation( + app_id="A111", + enterprise_id="E111", + team_id="T111", + user_id="U111", + bot_id="B111", + bot_token="xoxe.xoxp-1-initial", + bot_scopes=["chat:write"], + bot_user_id="U222", + bot_refresh_token="xoxe-1-initial", + bot_token_expires_in=43200, + custom_values={"foo": "bar"}, + ) + refreshed = self.token_rotator.perform_token_rotation( + installation=installation, minutes_before_expiration=60 * 24 * 365 + ) + self.assertIsNotNone(refreshed) + self.assertIsNotNone(refreshed.custom_values) + + should_not_be_refreshed = self.token_rotator.perform_token_rotation( + installation=installation, minutes_before_expiration=1 + ) + self.assertIsNone(should_not_be_refreshed) + def test_token_rotation_disabled(self): installation = Installation( app_id="A111", diff --git a/tests/slack_sdk_async/oauth/token_rotation/test_token_rotator.py b/tests/slack_sdk_async/oauth/token_rotation/test_token_rotator.py index edc957647..4c16bb3b5 100644 --- a/tests/slack_sdk_async/oauth/token_rotation/test_token_rotator.py +++ b/tests/slack_sdk_async/oauth/token_rotation/test_token_rotator.py @@ -45,6 +45,32 @@ async def test_refresh(self): ) self.assertIsNone(should_not_be_refreshed) + @async_test + async def test_refresh_with_custom_values(self): + installation = Installation( + app_id="A111", + enterprise_id="E111", + team_id="T111", + user_id="U111", + bot_id="B111", + bot_token="xoxe.xoxp-1-initial", + bot_scopes=["chat:write"], + bot_user_id="U222", + bot_refresh_token="xoxe-1-initial", + bot_token_expires_in=43200, + custom_values={"foo": "bar"}, + ) + refreshed = await self.token_rotator.perform_token_rotation( + installation=installation, minutes_before_expiration=60 * 24 * 365 + ) + self.assertIsNotNone(refreshed) + self.assertIsNotNone(refreshed.custom_values) + + should_not_be_refreshed = await self.token_rotator.perform_token_rotation( + installation=installation, minutes_before_expiration=1 + ) + self.assertIsNone(should_not_be_refreshed) + @async_test async def test_token_rotation_disabled(self): installation = Installation(