diff --git a/src/dsmlp/app/id_validator.py b/src/dsmlp/app/id_validator.py index c99f11d..99888aa 100644 --- a/src/dsmlp/app/id_validator.py +++ b/src/dsmlp/app/id_validator.py @@ -12,12 +12,13 @@ from dsmlp.plugin.logger import Logger from dsmlp.app.types import * + class IDValidator(ComponentValidator): - + def __init__(self, awsed: AwsedClient, logger: Logger) -> None: self.awsed = awsed self.logger = logger - + def validate_pod(self, request: Request): """ Validate pods for namespaces with the 'k8s-sync' label @@ -28,7 +29,8 @@ def validate_pod(self, request: Request): # if 'k8s-sync' in namespace.labels: user = self.awsed.describe_user(username) if not user: - raise ValidationFailure(f"namespace: no AWSEd user found with username {username}") + raise ValidationFailure( + f"namespace: no AWSEd user found with username {username}") allowed_uid = user.uid allowed_courses = user.enrollments @@ -39,15 +41,20 @@ def validate_pod(self, request: Request): metadata = request.object.metadata spec = request.object.spec - self.validate_course_enrollment(allowed_courses, metadata.labels) - self.validate_pod_security_context(allowed_uid, allowed_gids, spec.securityContext) + + if metadata is not None and metadata.labels is not None: + self.validate_course_enrollment(allowed_courses, metadata.labels) + + self.validate_pod_security_context( + allowed_uid, allowed_gids, spec.securityContext) self.validate_containers(allowed_uid, allowed_gids, spec) def validate_course_enrollment(self, allowed_courses: List[str], labels: Dict[str, str]): if not 'dsmlp/course' in labels: return if not labels['dsmlp/course'] in allowed_courses: - raise ValidationFailure(f"metadata.labels: dsmlp/course must be in range {allowed_courses}") + raise ValidationFailure( + f"metadata.labels: dsmlp/course must be in range {allowed_courses}") def validate_pod_security_context( self, @@ -59,18 +66,22 @@ def validate_pod_security_context( return if securityContext.runAsUser is not None and authorized_uid != securityContext.runAsUser: - raise ValidationFailure(f"spec.securityContext: uid must be in range [{authorized_uid}]") + raise ValidationFailure( + f"spec.securityContext: uid must be in range [{authorized_uid}]") if securityContext.runAsGroup is not None and securityContext.runAsGroup not in allowed_teams: - raise ValidationFailure(f"spec.securityContext: gid must be in range {allowed_teams}") + raise ValidationFailure( + f"spec.securityContext: gid must be in range {allowed_teams}") if securityContext.fsGroup is not None and securityContext.fsGroup not in allowed_teams: - raise ValidationFailure(f"spec.securityContext: gid must be in range {allowed_teams}") + raise ValidationFailure( + f"spec.securityContext: gid must be in range {allowed_teams}") if securityContext.supplementalGroups is not None: for sgroup in securityContext.supplementalGroups: if not sgroup in allowed_teams: - raise ValidationFailure(f"spec.securityContext: gid must be in range {allowed_teams}") + raise ValidationFailure( + f"spec.securityContext: gid must be in range {allowed_teams}") def validate_containers( self, @@ -81,8 +92,10 @@ def validate_containers( """ Validate the security context of containers and initContainers """ - self.validate_security_contexts(authorized_uid, allowed_teams, spec.containers, "containers") - self.validate_security_contexts(authorized_uid, allowed_teams, spec.initContainers, "initContainers") + self.validate_security_contexts( + authorized_uid, allowed_teams, spec.containers, "containers") + self.validate_security_contexts( + authorized_uid, allowed_teams, spec.initContainers, "initContainers") def validate_security_contexts( self, authorized_uid: int, allowed_teams: List[int], @@ -100,7 +113,8 @@ def validate_security_contexts( if securityContext is None: continue - self.validate_security_context(authorized_uid, allowed_teams, securityContext, f"{context}[{i}]") + self.validate_security_context( + authorized_uid, allowed_teams, securityContext, f"{context}[{i}]") def validate_security_context( self, @@ -128,4 +142,4 @@ def admission_response(self, uid, allowed, message): "message": message } } - } \ No newline at end of file + } diff --git a/src/dsmlp/ext/kube.py b/src/dsmlp/ext/kube.py index 7ad80a1..46e6c03 100644 --- a/src/dsmlp/ext/kube.py +++ b/src/dsmlp/ext/kube.py @@ -18,21 +18,21 @@ def get_namespace(self, name: str) -> Namespace: api = self.get_policy_api() v1namespace: V1Namespace = api.read_namespace(name=name) metadata: V1ObjectMeta = v1namespace.metadata - + gpu_quota = 1 - if metadata.annotations is not None and GPU_LIMIT_ANNOTATION in metadata.annotations: + if metadata is not None and metadata.annotations is not None and GPU_LIMIT_ANNOTATION in metadata.annotations: gpu_quota = int(metadata.annotations[GPU_LIMIT_ANNOTATION]) - + return Namespace( name=metadata.name, labels=metadata.labels, gpu_quota=gpu_quota) - + def get_gpus_in_namespace(self, name: str) -> int: api = self.get_policy_api() V1Namespace: V1Namespace = api.read_namespace(name=name) pods = api.list_namespaced_pod(namespace=name) - + gpu_count = 0 for pod in pods.items: for container in pod.spec.containers: @@ -45,13 +45,13 @@ def get_gpus_in_namespace(self, name: str) -> int: limit = int(container.resources.limits[GPU_LABEL]) except (KeyError, AttributeError, TypeError): pass - + gpu_count += max(requested, limit) - + return gpu_count - # noinspection PyMethodMayBeStatic + def get_policy_api(self) -> CoreV1Api: try: config.load_incluster_config() diff --git a/tests/app/utils.py b/tests/app/utils.py index 47a6a0f..ec128e6 100644 --- a/tests/app/utils.py +++ b/tests/app/utils.py @@ -30,6 +30,10 @@ def gen_request(gpu_req: int = 0, gpu_lim: int = 0, low_priority: bool = False, if course is not None: labels["dsmlp/course"] = course + metadata = None + if labels != {}: + metadata = ObjectMeta(labels=labels) + sec_context = None if run_as_user is not None or run_as_group is not None or fs_group is not None or supplemental_groups is not None: sec_context = PodSecurityContext( @@ -52,7 +56,7 @@ def gen_request(gpu_req: int = 0, gpu_lim: int = 0, low_priority: bool = False, uid=uid, namespace=username, object=Object( - metadata=ObjectMeta(labels=labels), + metadata=metadata, spec=PodSpec( containers=containers, priorityClassName=p_class,