Skip to content

Commit

Permalink
Merge branch 'tucker/add-slurm-user-mapper-with-azure' into tucker/1.…
Browse files Browse the repository at this point in the history
…6.0--release-candidate
  • Loading branch information
dusktreader committed Apr 28, 2022
2 parents 39e4e55 + 25d591a commit 2fec859
Show file tree
Hide file tree
Showing 15 changed files with 480 additions and 53 deletions.
1 change: 1 addition & 0 deletions cluster_agent/identity/slurm_user/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class MapperType(str, Enum):

LDAP = "LDAP"
SINGLE_USER = "SINGLE_USER"
AZURE_AD = "AZURE_AD"


class LDAPAuthType(str, Enum):
Expand Down
4 changes: 4 additions & 0 deletions cluster_agent/identity/slurm_user/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ class LDAPError(ClusterAgentError):

class SingleUserError(ClusterAgentError):
"""Raise exception when there is a problem with single-user submission."""


class AzureADError(ClusterAgentError):
"""Raise exception when communication with Azure AD fails."""
6 changes: 4 additions & 2 deletions cluster_agent/identity/slurm_user/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,19 @@
from cluster_agent.identity.slurm_user.mappers import (
SlurmUserMapper,
LDAPMapper,
AzureADMapper,
SingleUserMapper,
)


mapper_map = {
MapperType.LDAP: LDAPMapper,
MapperType.AZURE_AD: AzureADMapper,
MapperType.SINGLE_USER: SingleUserMapper,
}


def manufacture() -> SlurmUserMapper:
async def manufacture() -> SlurmUserMapper:
"""
Create an instance of a Slurm user mapper given the app configuration.
Expand All @@ -32,5 +34,5 @@ def manufacture() -> SlurmUserMapper:
)
assert mapper_class is not None
mapper_instance = mapper_class()
mapper_instance.configure(SLURM_USER_SETTINGS)
await mapper_instance.configure(SLURM_USER_SETTINGS)
return mapper_instance
2 changes: 2 additions & 0 deletions cluster_agent/identity/slurm_user/mappers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from cluster_agent.identity.slurm_user.mappers.azure_ad import AzureADMapper
from cluster_agent.identity.slurm_user.mappers.mapper_base import SlurmUserMapper
from cluster_agent.identity.slurm_user.mappers.ldap import LDAPMapper
from cluster_agent.identity.slurm_user.mappers.single_user import SingleUserMapper
Expand All @@ -7,4 +8,5 @@
"SlurmUserMapper",
"LDAPMapper",
"SingleUserMapper",
"AzureADMapper",
]
89 changes: 89 additions & 0 deletions cluster_agent/identity/slurm_user/mappers/azure_ad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import httpx
import typing

from loguru import logger
from pydantic import BaseModel, EmailStr, Extra

from cluster_agent.utils.logging import log_error

from cluster_agent.identity.slurm_user.exceptions import AzureADError
from cluster_agent.identity.slurm_user.mappers.mapper_base import SlurmUserMapper
from cluster_agent.identity.cluster_api import backend_client


class MemberIdentity(BaseModel, extra=Extra.ignore):
provider: str
access_token: str
user_id: str


class MemberDetail(BaseModel, extra=Extra.ignore):
email: EmailStr
identities: typing.List[MemberIdentity]
user_id: str


class Member(BaseModel, extra=Extra.ignore):
user_id: str
email: EmailStr
name: str


class MemberList(BaseModel, extra=Extra.ignore):
members: typing.List[Member]


class AzureResponse(BaseModel, extra=Extra.ignore):
mailNickName: str


class AzureADMapper(SlurmUserMapper):
"""
Provide a class to interface with the Azure AD for slurm user mapping.
"""

async def find_username(self, email: str) -> str:
"""
Find an Azure AD username given a user email.
"""

with AzureADError.handle_errors(
"Failed to fetch username from Azure AD",
do_except=log_error,
):
logger.debug(f"Searching for email {email} on Admin API")
response = await backend_client.get(
"/admin/management/organizations/members",
params=dict(search=email),
)
response.raise_for_status()

member_list = MemberList.parse_raw(response.content)
AzureADError.require_condition(
len(member_list.members) == 1,
f"Did not find exactly one match for email {email}",
)
member = member_list.members.pop()

logger.debug("Getting azure access token for user")
response = await backend_client.get(
f"/admin/management/organizations/members/{member.user_id}",
)
response.raise_for_status()

member_detail = MemberDetail.parse_raw(response.content)
AzureADError.require_condition(
len(member_detail.identities) == 1,
"Did not find exactly one embedded identity for the user",
)
identity = member_detail.identities.pop()

logger.debug("Requesting username from Azure AD")
response = httpx.get(
"https://graph.microsoft.com/v1.0/me?$select=mailNickName",
headers=dict(Authorization=f"Bearer {identity.access_token}"),
)
response.raise_for_status()
azure_details = AzureResponse.parse_raw(response.content)

return azure_details.mailNickName.lower()
5 changes: 3 additions & 2 deletions cluster_agent/identity/slurm_user/mappers/ldap.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ class LDAPMapper(SlurmUserMapper):
"""
Provide a class to interface with the LDAP server
"""

connection = None

def configure(self, settings: SlurmUserSettings):
async def configure(self, settings: SlurmUserSettings):
"""
Connect to the the LDAP server.
"""
Expand Down Expand Up @@ -66,7 +67,7 @@ def configure(self, settings: SlurmUserSettings):
self.connection.start_tls()
self.connection.bind()

def find_username(self, email: str) -> str:
async def find_username(self, email: str) -> str:
"""
Find an active diretory username given a user email.
Expand Down
9 changes: 4 additions & 5 deletions cluster_agent/identity/slurm_user/mappers/mapper_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,16 @@ class SlurmUserMapper:
- find_username(): Map a provided email address to a local slurm user.
"""

@abc.abstractmethod
def configure(self, settings: SlurmUserSettings):
async def configure(self, settings: SlurmUserSettings):
"""
Configure the mapper instance.
Must be implemented by any derived class
May be overridden by any derived class
"""
raise NotImplementedError
pass

@abc.abstractmethod
def find_username(self, email: str) -> str:
async def find_username(self, email: str) -> str:
"""
Find a slurm user name given an email.
Expand Down
5 changes: 3 additions & 2 deletions cluster_agent/identity/slurm_user/mappers/single_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@ class SingleUserMapper(SlurmUserMapper):
"""
Provide a class to interface with the LDAP server
"""

submitter = None

def configure(self, settings: SlurmUserSettings):
async def configure(self, settings: SlurmUserSettings):
"""
Connect to the the LDAP server.
"""
self.submitter = settings.SINGLE_USER_SUBMITTER

def find_username(self, *_) -> str:
async def find_username(self, *_) -> str:
"""
Find an active diretory username given a user email.
Expand Down
4 changes: 2 additions & 2 deletions cluster_agent/jobbergate/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def submit_job_script(
name = pending_job_submission.application_name
mapper_class_name = user_mapper.__class__.__name__
logger.debug(f"Fetching username for email {email} with mapper {mapper_class_name}")
username = user_mapper.find_username(email)
username = await user_mapper.find_username(email)
logger.debug(f"Using local slurm user {username} for job submission")

JobSubmissionError.require_condition(
Expand Down Expand Up @@ -102,7 +102,7 @@ async def submit_pending_jobs():
logger.debug("Started submitting pending jobs...")

logger.debug("Building user-mapper")
user_mapper = manufacture()
user_mapper = await manufacture()

logger.debug("Fetching pending jobs...")
pending_job_submissions = await fetch_pending_submissions()
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ PyJWT==2.1.0
httpx==0.22.0
py-buzz==3.1.0
ldap3==2.9.1
email-validator==1.1.3
Loading

0 comments on commit 2fec859

Please sign in to comment.