Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gpu quota #4

Merged
merged 14 commits into from
Mar 12, 2024
4 changes: 3 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,7 @@
],
"files.exclude": {
"**/__pycache__": true
}
},
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
621 changes: 621 additions & 0 deletions ref.json

Large diffs are not rendered by default.

Empty file added ref.txt
Empty file.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ PyHamcrest
requests_mock
dataclasses-json
python-dotenv
pytest
git+https://github.com/ucsd-ets/[email protected]
2 changes: 1 addition & 1 deletion src/dsmlp/admission_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def create_app(test_config=None):
logging.getLogger('waitress').setLevel(logging.INFO)
logging.getLogger('dsmlp').setLevel(logging.DEBUG)
logger = PythonLogger(None)
validator = Validator(factory.awsed_client, logger)
validator = Validator(factory.awsed_client, factory.kube_client, logger)

@app.route('/validate', methods=['POST'])
def validate_request():
Expand Down
3 changes: 3 additions & 0 deletions src/dsmlp/app/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
GPU_LABEL = "nvidia.com/gpu"
GPU_LIMIT_ANNOTATION = 'gpu-limit'
LOW_PRIORITY_CLASS = "low"
56 changes: 56 additions & 0 deletions src/dsmlp/app/gpu_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from dataclasses import dataclass
import json
from typing import List, Optional

from dataclasses_json import dataclass_json
from dsmlp.plugin.awsed import AwsedClient, UnsuccessfulRequest
from dsmlp.plugin.console import Console
from dsmlp.plugin.course import ConfigProvider
from dsmlp.plugin.kube import KubeClient, NotFound
import jsonify

from dsmlp.plugin.logger import Logger
from dsmlp.app.types import *
from dsmlp.app.config import *


class GPUValidator(ComponentValidator):

def __init__(self, kube: KubeClient, logger: Logger) -> None:
self.kube = kube
self.logger = logger

def validate_pod(self, request: Request):
"""
Validate pods for namespaces with the 'k8s-sync' label
"""

# Low priority pods pass through
priority = request.object.spec.priorityClassName
if priority is not None and priority == LOW_PRIORITY_CLASS:
return

namespace = self.kube.get_namespace(request.namespace)
curr_gpus = self.kube.get_gpus_in_namespace(request.namespace)

utilized_gpus = 0
for container in request.object.spec.containers:
requested, limit = 0, 0
try:
requested = int(container.resources.requests[GPU_LABEL])
except (KeyError, AttributeError, TypeError):
pass
try:
limit = int(container.resources.limits[GPU_LABEL])
except (KeyError, AttributeError, TypeError):
pass

utilized_gpus += max(requested, limit)

# Short circuit if no GPUs requested (permits overcap)
if utilized_gpus == 0:
return

if utilized_gpus + curr_gpus > namespace.gpu_quota:
raise ValidationFailure(
f"GPU quota exceeded. Wanted {utilized_gpus} but with {curr_gpus} already in use, the quota of {namespace.gpu_quota} would be exceeded.")
145 changes: 145 additions & 0 deletions src/dsmlp/app/id_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from dataclasses import dataclass
import json
from typing import List, Optional

from dataclasses_json import dataclass_json
from dsmlp.plugin.awsed import AwsedClient, UnsuccessfulRequest
from dsmlp.plugin.console import Console
from dsmlp.plugin.course import ConfigProvider
from dsmlp.plugin.kube import KubeClient, NotFound
import jsonify

from dsmlp.plugin.logger import Logger
from dsmlp.app.types import *


class IDValidator(ComponentValidator):

def __init__(self, awsed: AwsedClient, logger: Logger) -> None:
self.awsed = awsed
self.logger = logger

def validate_pod(self, request: Request):
"""
Validate pods for namespaces with the 'k8s-sync' label
"""
username = request.namespace
# namespace = self.kube.get_namespace(request.namespace)

# if 'k8s-sync' in namespace.labels:
user = self.awsed.describe_user(username)
if not user:
raise ValidationFailure(
f"namespace: no AWSEd user found with username {username}")
allowed_uid = user.uid
allowed_courses = user.enrollments

team_response = self.awsed.list_user_teams(username)
allowed_gids = [team.gid for team in team_response.teams]
allowed_gids.append(0)
allowed_gids.append(100)

metadata = request.object.metadata
spec = request.object.spec

if metadata is not None and metadata.labels is not None:
self.validate_course_enrollment(allowed_courses, metadata.labels)

self.validate_pod_security_context(
allowed_uid, allowed_gids, spec.securityContext)
self.validate_containers(allowed_uid, allowed_gids, spec)

def validate_course_enrollment(self, allowed_courses: List[str], labels: Dict[str, str]):
if not 'dsmlp/course' in labels:
return
if not labels['dsmlp/course'] in allowed_courses:
raise ValidationFailure(
f"metadata.labels: dsmlp/course must be in range {allowed_courses}")

def validate_pod_security_context(
self,
authorized_uid: int,
allowed_teams: List[int],
securityContext: PodSecurityContext):

if securityContext is None:
return

if securityContext.runAsUser is not None and authorized_uid != securityContext.runAsUser:
raise ValidationFailure(
f"spec.securityContext: uid must be in range [{authorized_uid}]")

if securityContext.runAsGroup is not None and securityContext.runAsGroup not in allowed_teams:
raise ValidationFailure(
f"spec.securityContext: gid must be in range {allowed_teams}")

if securityContext.fsGroup is not None and securityContext.fsGroup not in allowed_teams:
raise ValidationFailure(
f"spec.securityContext: gid must be in range {allowed_teams}")

if securityContext.supplementalGroups is not None:
for sgroup in securityContext.supplementalGroups:
if not sgroup in allowed_teams:
raise ValidationFailure(
f"spec.securityContext: gid must be in range {allowed_teams}")

def validate_containers(
self,
authorized_uid: int,
allowed_teams: List[int],
spec: PodSpec
):
"""
Validate the security context of containers and initContainers
"""
self.validate_security_contexts(
authorized_uid, allowed_teams, spec.containers, "containers")
self.validate_security_contexts(
authorized_uid, allowed_teams, spec.initContainers, "initContainers")

def validate_security_contexts(
self, authorized_uid: int, allowed_teams: List[int],
containers: List[Container],
context: str):
"""
Validate the security context of a container.
"""

if containers is None:
return

for i, container in enumerate(containers):
securityContext = container.securityContext
if securityContext is None:
continue

self.validate_security_context(
authorized_uid, allowed_teams, securityContext, f"{context}[{i}]")

def validate_security_context(
self,
authorized_uid: int,
allowed_teams: List[int],
securityContext: SecurityContext,
context: str):

if securityContext.runAsUser is not None and authorized_uid != securityContext.runAsUser:
raise ValidationFailure(
f"spec.{context}.securityContext: uid must be in range [{authorized_uid}]")

if securityContext.runAsGroup is not None and securityContext.runAsGroup not in allowed_teams:
raise ValidationFailure(
f"spec.{context}.securityContext: gid must be in range {allowed_teams}")

def admission_response(self, uid, allowed, message):
return {
"apiVersion": "admission.k8s.io/v1",
"kind": "AdmissionReview",
"response": {
"uid": uid,
"allowed": allowed,
"status": {
"message": message
}
}
}
85 changes: 85 additions & 0 deletions src/dsmlp/app/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@

from dataclasses import dataclass
from typing import List, Optional, Dict
from dataclasses_json import dataclass_json
from abc import ABCMeta, abstractmethod

@dataclass_json
@dataclass
class SecurityContext:
"""Each Container has a SecurityContext"""
runAsUser: Optional[int] = None
runAsGroup: Optional[int] = None

@dataclass_json
@dataclass
class ResourceRequirements:
requests: Optional[Dict[str, int]] = None
limits: Optional[Dict[str, int]] = None

@dataclass_json
@dataclass
class Container:
securityContext: Optional[SecurityContext] = None
resources: Optional[ResourceRequirements] = None

@dataclass_json
@dataclass
class PodSecurityContext:
"""Each Pod has a SecurityContext"""
runAsUser: Optional[int] = None
runAsGroup: Optional[int] = None
fsGroup: Optional[int] = None
supplementalGroups: Optional[List[int]] = None


@dataclass_json
@dataclass
class PodSpec:
containers: List[Container]
initContainers: Optional[List[Container]] = None
securityContext: Optional[PodSecurityContext] = None
priorityClassName: Optional[str] = None

@dataclass_json
@dataclass
class ObjectMeta:
labels: Dict[str, str]


@dataclass_json
@dataclass
class Object:
metadata: ObjectMeta
spec: PodSpec


@dataclass_json
@dataclass
class UserInfo:
username: str


@dataclass_json
@dataclass
class Request:
uid: str
namespace: str
object: Object
userInfo: UserInfo


@dataclass_json
@dataclass
class AdmissionReview:
request: Request

class ValidationFailure(Exception):
def __init__(self, message: str) -> None:
self.message = message
super().__init__(self.message)

class ComponentValidator:
@abstractmethod
def validate_pod(self, request: Request):
pass
Loading
Loading