diff --git a/cluster_agent/identity/slurm_user/constants.py b/cluster_agent/identity/slurm_user/constants.py index a152bfe..d85a332 100644 --- a/cluster_agent/identity/slurm_user/constants.py +++ b/cluster_agent/identity/slurm_user/constants.py @@ -12,6 +12,7 @@ class MapperType(str, Enum): LDAP = "LDAP" SINGLE_USER = "SINGLE_USER" + AZURE_AD = "AZURE_AD" class LDAPAuthType(str, Enum): diff --git a/cluster_agent/identity/slurm_user/exceptions.py b/cluster_agent/identity/slurm_user/exceptions.py index d168370..c19e678 100644 --- a/cluster_agent/identity/slurm_user/exceptions.py +++ b/cluster_agent/identity/slurm_user/exceptions.py @@ -15,3 +15,7 @@ class LDAPError(ClusterAgentError): class SingleUserError(ClusterAgentError): """Raise exception when there is a problem with single-user submission.""" + + +class AzureADError(ClusterAgentError): + """Raise exception when communication with Azure AD fails.""" diff --git a/cluster_agent/identity/slurm_user/factory.py b/cluster_agent/identity/slurm_user/factory.py index 18e658a..18e4a1f 100644 --- a/cluster_agent/identity/slurm_user/factory.py +++ b/cluster_agent/identity/slurm_user/factory.py @@ -8,17 +8,19 @@ from cluster_agent.identity.slurm_user.mappers import ( SlurmUserMapper, LDAPMapper, + AzureADMapper, SingleUserMapper, ) mapper_map = { MapperType.LDAP: LDAPMapper, + MapperType.AZURE_AD: AzureADMapper, MapperType.SINGLE_USER: SingleUserMapper, } -def manufacture() -> SlurmUserMapper: +async def manufacture() -> SlurmUserMapper: """ Create an instance of a Slurm user mapper given the app configuration. @@ -32,5 +34,5 @@ def manufacture() -> SlurmUserMapper: ) assert mapper_class is not None mapper_instance = mapper_class() - mapper_instance.configure(SLURM_USER_SETTINGS) + await mapper_instance.configure(SLURM_USER_SETTINGS) return mapper_instance diff --git a/cluster_agent/identity/slurm_user/mappers/__init__.py b/cluster_agent/identity/slurm_user/mappers/__init__.py index 9d6cf39..ee9fcaa 100644 --- a/cluster_agent/identity/slurm_user/mappers/__init__.py +++ b/cluster_agent/identity/slurm_user/mappers/__init__.py @@ -1,3 +1,4 @@ +from cluster_agent.identity.slurm_user.mappers.azure_ad import AzureADMapper from cluster_agent.identity.slurm_user.mappers.mapper_base import SlurmUserMapper from cluster_agent.identity.slurm_user.mappers.ldap import LDAPMapper from cluster_agent.identity.slurm_user.mappers.single_user import SingleUserMapper @@ -7,4 +8,5 @@ "SlurmUserMapper", "LDAPMapper", "SingleUserMapper", + "AzureADMapper", ] diff --git a/cluster_agent/identity/slurm_user/mappers/azure_ad.py b/cluster_agent/identity/slurm_user/mappers/azure_ad.py new file mode 100644 index 0000000..236f54f --- /dev/null +++ b/cluster_agent/identity/slurm_user/mappers/azure_ad.py @@ -0,0 +1,89 @@ +import httpx +import typing + +from loguru import logger +from pydantic import BaseModel, EmailStr, Extra + +from cluster_agent.utils.logging import log_error + +from cluster_agent.identity.slurm_user.exceptions import AzureADError +from cluster_agent.identity.slurm_user.mappers.mapper_base import SlurmUserMapper +from cluster_agent.identity.cluster_api import backend_client + + +class MemberIdentity(BaseModel, extra=Extra.ignore): + provider: str + access_token: str + user_id: str + + +class MemberDetail(BaseModel, extra=Extra.ignore): + email: EmailStr + identities: typing.List[MemberIdentity] + user_id: str + + +class Member(BaseModel, extra=Extra.ignore): + user_id: str + email: EmailStr + name: str + + +class MemberList(BaseModel, extra=Extra.ignore): + members: typing.List[Member] + + +class AzureResponse(BaseModel, extra=Extra.ignore): + mailNickName: str + + +class AzureADMapper(SlurmUserMapper): + """ + Provide a class to interface with the Azure AD for slurm user mapping. + """ + + async def find_username(self, email: str) -> str: + """ + Find an Azure AD username given a user email. + """ + + with AzureADError.handle_errors( + "Failed to fetch username from Azure AD", + do_except=log_error, + ): + logger.debug(f"Searching for email {email} on Admin API") + response = await backend_client.get( + "/admin/management/organizations/members", + params=dict(search=email), + ) + response.raise_for_status() + + member_list = MemberList.parse_raw(response.content) + AzureADError.require_condition( + len(member_list.members) == 1, + f"Did not find exactly one match for email {email}", + ) + member = member_list.members.pop() + + logger.debug("Getting azure access token for user") + response = await backend_client.get( + f"/admin/management/organizations/members/{member.user_id}", + ) + response.raise_for_status() + + member_detail = MemberDetail.parse_raw(response.content) + AzureADError.require_condition( + len(member_detail.identities) == 1, + "Did not find exactly one embedded identity for the user", + ) + identity = member_detail.identities.pop() + + logger.debug("Requesting username from Azure AD") + response = httpx.get( + "https://graph.microsoft.com/v1.0/me?$select=mailNickName", + headers=dict(Authorization=f"Bearer {identity.access_token}"), + ) + response.raise_for_status() + azure_details = AzureResponse.parse_raw(response.content) + + return azure_details.mailNickName.lower() diff --git a/cluster_agent/identity/slurm_user/mappers/ldap.py b/cluster_agent/identity/slurm_user/mappers/ldap.py index 1006e6b..5be7ece 100644 --- a/cluster_agent/identity/slurm_user/mappers/ldap.py +++ b/cluster_agent/identity/slurm_user/mappers/ldap.py @@ -15,9 +15,10 @@ class LDAPMapper(SlurmUserMapper): """ Provide a class to interface with the LDAP server """ + connection = None - def configure(self, settings: SlurmUserSettings): + async def configure(self, settings: SlurmUserSettings): """ Connect to the the LDAP server. """ @@ -66,7 +67,7 @@ def configure(self, settings: SlurmUserSettings): self.connection.start_tls() self.connection.bind() - def find_username(self, email: str) -> str: + async def find_username(self, email: str) -> str: """ Find an active diretory username given a user email. diff --git a/cluster_agent/identity/slurm_user/mappers/mapper_base.py b/cluster_agent/identity/slurm_user/mappers/mapper_base.py index 04c7955..0787e7a 100644 --- a/cluster_agent/identity/slurm_user/mappers/mapper_base.py +++ b/cluster_agent/identity/slurm_user/mappers/mapper_base.py @@ -16,17 +16,16 @@ class SlurmUserMapper: - find_username(): Map a provided email address to a local slurm user. """ - @abc.abstractmethod - def configure(self, settings: SlurmUserSettings): + async def configure(self, settings: SlurmUserSettings): """ Configure the mapper instance. - Must be implemented by any derived class + May be overridden by any derived class """ - raise NotImplementedError + pass @abc.abstractmethod - def find_username(self, email: str) -> str: + async def find_username(self, email: str) -> str: """ Find a slurm user name given an email. diff --git a/cluster_agent/identity/slurm_user/mappers/single_user.py b/cluster_agent/identity/slurm_user/mappers/single_user.py index 2e34e88..25761ae 100644 --- a/cluster_agent/identity/slurm_user/mappers/single_user.py +++ b/cluster_agent/identity/slurm_user/mappers/single_user.py @@ -11,15 +11,16 @@ class SingleUserMapper(SlurmUserMapper): """ Provide a class to interface with the LDAP server """ + submitter = None - def configure(self, settings: SlurmUserSettings): + async def configure(self, settings: SlurmUserSettings): """ Connect to the the LDAP server. """ self.submitter = settings.SINGLE_USER_SUBMITTER - def find_username(self, *_) -> str: + async def find_username(self, *_) -> str: """ Find an active diretory username given a user email. diff --git a/cluster_agent/jobbergate/submit.py b/cluster_agent/jobbergate/submit.py index ecd1eff..1077a42 100644 --- a/cluster_agent/jobbergate/submit.py +++ b/cluster_agent/jobbergate/submit.py @@ -48,7 +48,7 @@ async def submit_job_script( name = pending_job_submission.application_name mapper_class_name = user_mapper.__class__.__name__ logger.debug(f"Fetching username for email {email} with mapper {mapper_class_name}") - username = user_mapper.find_username(email) + username = await user_mapper.find_username(email) logger.debug(f"Using local slurm user {username} for job submission") JobSubmissionError.require_condition( @@ -102,7 +102,7 @@ async def submit_pending_jobs(): logger.debug("Started submitting pending jobs...") logger.debug("Building user-mapper") - user_mapper = manufacture() + user_mapper = await manufacture() logger.debug("Fetching pending jobs...") pending_job_submissions = await fetch_pending_submissions() diff --git a/requirements.txt b/requirements.txt index 60d4714..cbb0dfb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,3 +11,4 @@ PyJWT==2.1.0 httpx==0.22.0 py-buzz==3.1.0 ldap3==2.9.1 +email-validator==1.1.3 diff --git a/tests/identity/slurm_user/mappers/test_azure_ad.py b/tests/identity/slurm_user/mappers/test_azure_ad.py new file mode 100644 index 0000000..f559e74 --- /dev/null +++ b/tests/identity/slurm_user/mappers/test_azure_ad.py @@ -0,0 +1,319 @@ +""" +Define tests for the Azure AD mapper. +""" + +import httpx +import pytest +import respx + +from cluster_agent.settings import SETTINGS +from cluster_agent.identity.slurm_user.exceptions import AzureADError +from cluster_agent.identity.slurm_user.mappers import azure_ad + + +async def test_find_username__success(): + """ + Test that an AzureADMapper can fetch a username given an email. + """ + async with respx.mock: + respx.post(f"https://{SETTINGS.AUTH0_DOMAIN}/oauth/token").mock( + return_value=httpx.Response( + status_code=200, + json=dict(access_token="dummy-token"), + ) + ) + + members_route = respx.get( + f"{SETTINGS.BASE_API_URL}/admin/management/organizations/members" + ) + members_route.mock( + return_value=httpx.Response( + status_code=200, + json=dict( + members=[ + dict( + user_id="dummy-id", + email="dummy_user@dummy.domain.com", + name="Dummy Dummerson", + ), + ], + ), + ) + ) + + details_route = respx.get( + f"{SETTINGS.BASE_API_URL}/admin/management/organizations/members/dummy-id" + ) + details_route.mock( + return_value=httpx.Response( + status_code=200, + json=dict( + email="dummy_user@dummy.domain.com", + identities=[ + dict( + provider="dummy provider", + access_token="dummy-azure-token", + user_id="dummy-id", + ), + ], + user_id="dummy-id", + ), + ) + ) + + azure_route = respx.get( + "https://graph.microsoft.com/v1.0/me?$select=mailNickName" + ) + azure_route.mock( + return_value=httpx.Response( + status_code=200, + json=dict( + mailNickName="DDU00D", + ), + ) + ) + + mapper = azure_ad.AzureADMapper() + + username = await mapper.find_username("dummy_user@dummy.domain.com") + assert username == "ddu00d" + + +async def test_find_username__fails_if_email_search_fails(): + """ + Test that an AzureADMapper raises an exception if no user is found for the email. + """ + async with respx.mock: + respx.post(f"https://{SETTINGS.AUTH0_DOMAIN}/oauth/token").mock( + return_value=httpx.Response( + status_code=200, + json=dict(access_token="dummy-token"), + ) + ) + + members_route = respx.get( + f"{SETTINGS.BASE_API_URL}/admin/management/organizations/members" + ) + members_route.mock( + return_value=httpx.Response( + status_code=404, + ) + ) + + with pytest.raises(AzureADError, match="Failed to fetch username from Azure AD"): + mapper = azure_ad.AzureADMapper() + await mapper.find_username("dummy_user@dummy.domain.com") + + +async def test_find_username__fails_if_email_search_has_multiple_hits(): + """ + Test that an AzureADMapper raises an exception if multiple users are found. + """ + async with respx.mock: + respx.post(f"https://{SETTINGS.AUTH0_DOMAIN}/oauth/token").mock( + return_value=httpx.Response( + status_code=200, + json=dict(access_token="dummy-token"), + ) + ) + + members_route = respx.get( + f"{SETTINGS.BASE_API_URL}/admin/management/organizations/members" + ) + members_route.mock( + return_value=httpx.Response( + status_code=200, + json=dict( + members=[ + dict( + user_id="dummy-id", + email="dummy_user@dummy.domain.com", + name="Dummy Dummerson", + ), + dict( + user_id="stupid-id", + email="dummy_user@dummy.domain.com", + name="Stupid Vanderstupid", + ), + ], + ), + ) + ) + + with pytest.raises( + AzureADError, + match="Failed to fetch username from Azure AD.*Did not find exactly one", + ): + mapper = azure_ad.AzureADMapper() + await mapper.find_username("dummy_user@dummy.domain.com") + + +async def test_find_username__fails_if_detail_request_fails(): + """ + Test that an AzureADMapper raises an error if the request for user details fails. + """ + async with respx.mock: + respx.post(f"https://{SETTINGS.AUTH0_DOMAIN}/oauth/token").mock( + return_value=httpx.Response( + status_code=200, + json=dict(access_token="dummy-token"), + ) + ) + + members_route = respx.get( + f"{SETTINGS.BASE_API_URL}/admin/management/organizations/members" + ) + members_route.mock( + return_value=httpx.Response( + status_code=200, + json=dict( + members=[ + dict( + user_id="dummy-id", + email="dummy_user@dummy.domain.com", + name="Dummy Dummerson", + ), + ], + ), + ) + ) + + details_route = respx.get( + f"{SETTINGS.BASE_API_URL}/admin/management/organizations/members/dummy-id" + ) + details_route.mock( + return_value=httpx.Response( + status_code=400, + ) + ) + + with pytest.raises(AzureADError, match="Failed to fetch username"): + mapper = azure_ad.AzureADMapper() + await mapper.find_username("dummy_user@dummy.domain.com") + + +async def test_find_username__fails_if_multiple_identities_found(): + """ + Test that an AzureADMapper raises an error if multiple identities are found. + """ + async with respx.mock: + respx.post(f"https://{SETTINGS.AUTH0_DOMAIN}/oauth/token").mock( + return_value=httpx.Response( + status_code=200, + json=dict(access_token="dummy-token"), + ) + ) + + members_route = respx.get( + f"{SETTINGS.BASE_API_URL}/admin/management/organizations/members" + ) + members_route.mock( + return_value=httpx.Response( + status_code=200, + json=dict( + members=[ + dict( + user_id="dummy-id", + email="dummy_user@dummy.domain.com", + name="Dummy Dummerson", + ), + ], + ), + ) + ) + + details_route = respx.get( + f"{SETTINGS.BASE_API_URL}/admin/management/organizations/members/dummy-id" + ) + details_route.mock( + return_value=httpx.Response( + status_code=200, + json=dict( + email="dummy_user@dummy.domain.com", + identities=[ + dict( + provider="dummy provider", + access_token="dummy-azure-token", + user_id="dummy-id", + ), + dict( + provider="stupid provider", + access_token="stupid-azure-token", + user_id="stupid-id", + ), + ], + user_id="dummy-id", + ), + ) + ) + + with pytest.raises( + AzureADError, + match="Failed to fetch username.*Did not find exactly one embedded", + ): + mapper = azure_ad.AzureADMapper() + await mapper.find_username("dummy_user@dummy.domain.com") + + +async def test_find_username__fails_if_microsoft_graph_call_fails(): + """ + Test that an AzureADMapper raises an exception if the call to the graph api fails. + """ + async with respx.mock: + respx.post(f"https://{SETTINGS.AUTH0_DOMAIN}/oauth/token").mock( + return_value=httpx.Response( + status_code=200, + json=dict(access_token="dummy-token"), + ) + ) + + members_route = respx.get( + f"{SETTINGS.BASE_API_URL}/admin/management/organizations/members" + ) + members_route.mock( + return_value=httpx.Response( + status_code=200, + json=dict( + members=[ + dict( + user_id="dummy-id", + email="dummy_user@dummy.domain.com", + name="Dummy Dummerson", + ), + ], + ), + ) + ) + + details_route = respx.get( + f"{SETTINGS.BASE_API_URL}/admin/management/organizations/members/dummy-id" + ) + details_route.mock( + return_value=httpx.Response( + status_code=200, + json=dict( + email="dummy_user@dummy.domain.com", + identities=[ + dict( + provider="dummy provider", + access_token="dummy-azure-token", + user_id="dummy-id", + ), + ], + user_id="dummy-id", + ), + ) + ) + + azure_route = respx.get( + "https://graph.microsoft.com/v1.0/me?$select=mailNickName" + ) + azure_route.mock( + return_value=httpx.Response( + status_code=400, + ) + ) + + with pytest.raises(AzureADError, match="Failed to fetch username"): + mapper = azure_ad.AzureADMapper() + await mapper.find_username("dummy_user@dummy.domain.com") diff --git a/tests/identity/slurm_user/mappers/test_ldap.py b/tests/identity/slurm_user/mappers/test_ldap.py index cae0420..233f98f 100644 --- a/tests/identity/slurm_user/mappers/test_ldap.py +++ b/tests/identity/slurm_user/mappers/test_ldap.py @@ -11,7 +11,7 @@ from cluster_agent.identity.slurm_user.settings import SLURM_USER_SETTINGS -def test_configure__success(mocker, tweak_slurm_user_settings): +async def test_configure__success(mocker, tweak_slurm_user_settings): """ Test that an LDAP instance will ``configure()`` if settings are correct. """ @@ -26,7 +26,7 @@ def test_configure__success(mocker, tweak_slurm_user_settings): LDAP_USERNAME="dummyUser", LDAP_PASSWORD="dummy-password", ): - mapper.configure(SLURM_USER_SETTINGS) + await mapper.configure(SLURM_USER_SETTINGS) mock_server.assert_called_once_with(SLURM_USER_SETTINGS.LDAP_DOMAIN, get_info=ldap.ALL) mock_connection.assert_called_once_with( mock_server_obj, @@ -37,7 +37,7 @@ def test_configure__success(mocker, tweak_slurm_user_settings): assert mapper.search_base == "DC=dummy,DC=domain,DC=com" -def test_configure__sets_up_ntlm_auth_type_correctly(mocker, tweak_slurm_user_settings): +async def test_configure__sets_up_ntlm_auth_type_correctly(mocker, tweak_slurm_user_settings): """ Test that an LDAP instance will ``configure()`` NTLM auth correctly. """ @@ -53,7 +53,7 @@ def test_configure__sets_up_ntlm_auth_type_correctly(mocker, tweak_slurm_user_se LDAP_PASSWORD="dummy-password", LDAP_AUTH_TYPE="NTLM", ): - mapper.configure(SLURM_USER_SETTINGS) + await mapper.configure(SLURM_USER_SETTINGS) mock_server.assert_called_once_with(SLURM_USER_SETTINGS.LDAP_DOMAIN, get_info=ldap.ALL) mock_connection.assert_called_once_with( mock_server_obj, @@ -64,7 +64,7 @@ def test_configure__sets_up_ntlm_auth_type_correctly(mocker, tweak_slurm_user_se assert mapper.search_base == "DC=dummy,DC=domain,DC=com" -def test_configure__raises_LDAPError_if_settings_are_missing(tweak_slurm_user_settings): +async def test_configure__raises_LDAPError_if_settings_are_missing(tweak_slurm_user_settings): """ Test that the ``configure()`` method will fail if settings are not correct. @@ -80,7 +80,7 @@ def test_configure__raises_LDAPError_if_settings_are_missing(tweak_slurm_user_se LDAP_PASSWORD="dummy-password", ): with pytest.raises(LDAPError, match="LDAP is not configured"): - mapper.configure(SLURM_USER_SETTINGS) + await mapper.configure(SLURM_USER_SETTINGS) with tweak_slurm_user_settings( LDAP_DOMAIN="dummy.domain.com", @@ -89,7 +89,7 @@ def test_configure__raises_LDAPError_if_settings_are_missing(tweak_slurm_user_se LDAP_PASSWORD="dummy-password", ): with pytest.raises(LDAPError, match="LDAP is not configured"): - mapper.configure(SLURM_USER_SETTINGS) + await mapper.configure(SLURM_USER_SETTINGS) with tweak_slurm_user_settings( LDAP_DOMAIN="dummy.domain.com", @@ -98,7 +98,7 @@ def test_configure__raises_LDAPError_if_settings_are_missing(tweak_slurm_user_se LDAP_PASSWORD="dummy-password", ): with pytest.raises(LDAPError, match="LDAP is not configured"): - mapper.configure(SLURM_USER_SETTINGS) + await mapper.configure(SLURM_USER_SETTINGS) with tweak_slurm_user_settings( LDAP_DOMAIN="dummy.domain.com", @@ -107,10 +107,10 @@ def test_configure__raises_LDAPError_if_settings_are_missing(tweak_slurm_user_se LDAP_PASSWORD=None, ): with pytest.raises(LDAPError, match="LDAP is not configured"): - mapper.configure(SLURM_USER_SETTINGS) + await mapper.configure(SLURM_USER_SETTINGS) -def test_find_username__success(mocker, tweak_slurm_user_settings): +async def test_find_username__success(mocker, tweak_slurm_user_settings): """ Test that the ``find_username()`` gets username from ldap server given email. @@ -138,12 +138,15 @@ def test_find_username__success(mocker, tweak_slurm_user_settings): LDAP_USERNAME="dummyUser", LDAP_PASSWORD="dummy-password", ): - mapper.configure(SLURM_USER_SETTINGS) - username = mapper.find_username("dummy_user@dummy.domain.com") + await mapper.configure(SLURM_USER_SETTINGS) + username = await mapper.find_username("dummy_user@dummy.domain.com") assert username == "xxx00x" -def test_find_username__fails_if_server_does_not_return_1_entry(mocker, tweak_slurm_user_settings): +async def test_find_username__fails_if_server_does_not_return_1_entry( + mocker, + tweak_slurm_user_settings, +): """ Test that the ``find_username()`` fails if server does not return exactly 1 entry. @@ -166,18 +169,21 @@ def test_find_username__fails_if_server_does_not_return_1_entry(mocker, tweak_sl LDAP_USERNAME="dummyUser", LDAP_PASSWORD="dummy-password", ): - mapper.configure(SLURM_USER_SETTINGS) + await mapper.configure(SLURM_USER_SETTINGS) mock_connection_obj.entries = [] with pytest.raises(LDAPError, match="Did not find exactly one"): - mapper.find_username("dummy_user@dummy.domain.com") + await mapper.find_username("dummy_user@dummy.domain.com") mock_connection_obj.entries = [1, 2] with pytest.raises(LDAPError, match="Did not find exactly one"): - mapper.find_username("dummy_user@dummy.domain.com") + await mapper.find_username("dummy_user@dummy.domain.com") -def test_find_username__fails_if_entries_cannot_be_extracted(mocker, tweak_slurm_user_settings): +async def test_find_username__fails_if_entries_cannot_be_extracted( + mocker, + tweak_slurm_user_settings, +): """ Test that the ``find_username()`` fails if entries are invalid. @@ -204,12 +210,12 @@ def test_find_username__fails_if_entries_cannot_be_extracted(mocker, tweak_slurm LDAP_USERNAME="dummyUser", LDAP_PASSWORD="dummy-password", ): - mapper.configure(SLURM_USER_SETTINGS) + await mapper.configure(SLURM_USER_SETTINGS) with pytest.raises(LDAPError, match="Failed to extract data"): - mapper.find_username("dummy_user@dummy.domain.com") + await mapper.find_username("dummy_user@dummy.domain.com") -def test_find_username__fails_if_user_has_more_than_one_CN(mocker, tweak_slurm_user_settings): +async def test_find_username__fails_if_user_has_more_than_one_CN(mocker, tweak_slurm_user_settings): """ Test that the ``find_username()`` fails if a user has more than one username. @@ -236,11 +242,11 @@ def test_find_username__fails_if_user_has_more_than_one_CN(mocker, tweak_slurm_u LDAP_USERNAME="dummyUser", LDAP_PASSWORD="dummy-password", ): - mapper.configure(SLURM_USER_SETTINGS) + await mapper.configure(SLURM_USER_SETTINGS) mock_entry.entry_to_json = lambda: json.dumps(dict(attributes=dict(cn=[]))) with pytest.raises(LDAPError, match="User did not have exactly one CN"): - mapper.find_username("dummy_user@dummy.domain.com") + await mapper.find_username("dummy_user@dummy.domain.com") mock_entry.entry_to_json = lambda: json.dumps(dict(attributes=dict(cn=[1, 2]))) with pytest.raises(LDAPError, match="User did not have exactly one CN"): - mapper.find_username("dummy_user@dummy.domain.com") + await mapper.find_username("dummy_user@dummy.domain.com") diff --git a/tests/identity/slurm_user/mappers/test_single_user.py b/tests/identity/slurm_user/mappers/test_single_user.py index 1dd5dc7..8241943 100644 --- a/tests/identity/slurm_user/mappers/test_single_user.py +++ b/tests/identity/slurm_user/mappers/test_single_user.py @@ -9,33 +9,33 @@ from cluster_agent.identity.slurm_user.settings import SLURM_USER_SETTINGS -def test_configure__success(mocker, tweak_slurm_user_settings): +async def test_configure__success(mocker, tweak_slurm_user_settings): """ Test that a SingleUserMapper instance ``configures()`` correctly. """ mapper = single_user.SingleUserMapper() with tweak_slurm_user_settings(SINGLE_USER_SUBMITTER="dummy-user"): - mapper.configure(SLURM_USER_SETTINGS) + await mapper.configure(SLURM_USER_SETTINGS) assert mapper.submitter == "dummy-user" -def test_find_username__success(mocker, tweak_slurm_user_settings): +async def test_find_username__success(mocker, tweak_slurm_user_settings): """ Test that the ``find_username()`` uses the single user as the submit username. """ mapper = single_user.SingleUserMapper() with tweak_slurm_user_settings(SINGLE_USER_SUBMITTER="dummy-user"): - mapper.configure(SLURM_USER_SETTINGS) - username = mapper.find_username("does.not@matter.com") + await mapper.configure(SLURM_USER_SETTINGS) + username = await mapper.find_username("does.not@matter.com") assert username == "dummy-user" -def test_find_username__fails_mapper_is_not_configured(mocker, tweak_slurm_user_settings): +async def test_find_username__fails_mapper_is_not_configured(mocker, tweak_slurm_user_settings): """ Test that the ``find_username()`` fails if the mapper is not configured. """ mapper = single_user.SingleUserMapper() with pytest.raises(SingleUserError, match="No username set"): - mapper.find_username("dummy_user@dummy.domain.com") + await mapper.find_username("dummy_user@dummy.domain.com") diff --git a/tests/identity/slurm_user/test_factory.py b/tests/identity/slurm_user/test_factory.py index 5848aa9..68c73ba 100644 --- a/tests/identity/slurm_user/test_factory.py +++ b/tests/identity/slurm_user/test_factory.py @@ -2,27 +2,28 @@ from cluster_agent.identity.slurm_user.exceptions import MapperFactoryError from cluster_agent.identity.slurm_user.factory import manufacture +from cluster_agent.identity.slurm_user.mappers.ldap import LDAPMapper from cluster_agent.identity.slurm_user.settings import SLURM_USER_SETTINGS from cluster_agent.identity.slurm_user.constants import MapperType -def test_manufacture__with_valid_mapper_name(tweak_slurm_user_settings, mocker): - mocked_ldap_instance = mocker.MagicMock() +async def test_manufacture__with_valid_mapper_name(tweak_slurm_user_settings, mocker): + mocked_ldap_instance = mocker.AsyncMock(LDAPMapper) mocked_ldap_class = mocker.MagicMock(return_value=mocked_ldap_instance) mocker.patch.dict( "cluster_agent.identity.slurm_user.factory.mapper_map", {MapperType.LDAP: mocked_ldap_class}, ) with tweak_slurm_user_settings(SLURM_USER_MAPPER=MapperType.LDAP): - manufacture() + await manufacture() mocked_ldap_instance.configure.assert_called_once_with(SLURM_USER_SETTINGS) -def test_manufacture__with_invalid_mapper_name(tweak_slurm_user_settings, mocker): +async def test_manufacture__with_invalid_mapper_name(tweak_slurm_user_settings, mocker): mocker.patch.dict( "cluster_agent.identity.slurm_user.factory.mapper_map", {"REAL": mocker.MagicMock()}, ) with tweak_slurm_user_settings(SLURM_USER_MAPPER="FAKE"): with pytest.raises(MapperFactoryError, match="Couldn't find a mapper class"): - manufacture() + await manufacture() diff --git a/tests/jobbergate/test_submit.py b/tests/jobbergate/test_submit.py index bca04fa..ae67c79 100644 --- a/tests/jobbergate/test_submit.py +++ b/tests/jobbergate/test_submit.py @@ -9,8 +9,9 @@ import pytest import respx -from cluster_agent.utils.exception import JobSubmissionError, SlurmrestdError +from cluster_agent.identity.slurm_user.mappers.mapper_base import SlurmUserMapper from cluster_agent.settings import SETTINGS +from cluster_agent.utils.exception import JobSubmissionError, SlurmrestdError from cluster_agent.jobbergate.schemas import ( PendingJobSubmission, SlurmJobSubmission, @@ -31,7 +32,7 @@ async def test_submit_job_script__success( and that a ``slurm_job_id`` is returned. Verifies that LDAP was used to retrieve the username. """ - user_mapper = mocker.MagicMock() + user_mapper = mocker.AsyncMock(SlurmUserMapper) user_mapper.find_username.return_value = "dummy-user" mocker.patch( @@ -145,7 +146,7 @@ async def test_submit_job_script__raises_exception_if_no_executable_script_was_f ) with pytest.raises(JobSubmissionError, match="Could not find an executable"): - await submit_job_script(pending_job_submission, mocker.MagicMock()) + await submit_job_script(pending_job_submission, mocker.AsyncMock(SlurmUserMapper)) @pytest.mark.asyncio @@ -158,7 +159,7 @@ async def test_submit_job_script__raises_exception_if_submit_call_response_is_no REST API is nota 200. Verifies that the error message is included in the raised exception. """ - user_mapper = mocker.MagicMock() + user_mapper = mocker.AsyncMock(SlurmUserMapper) user_mapper.find_username.return_value = "dummy-user" mocker.patch( @@ -201,7 +202,7 @@ async def test_submit_job_script__raises_exception_if_response_cannot_be_unpacke REST API is nota 200. Verifies that the error message is included in the raised exception. """ - user_mapper = mocker.MagicMock() + user_mapper = mocker.AsyncMock(SlurmUserMapper) user_mapper.find_username.return_value = "dummy-user" mocker.patch(