From 8191929822f67c870ab1fb9ad67d8d8496e9df42 Mon Sep 17 00:00:00 2001 From: Rockford Mankini Date: Mon, 7 Oct 2024 16:54:12 -0700 Subject: [PATCH] add test cases for tgpt validator --- src/dsmlp/app/tritongpt_validator.py | 8 +- src/dsmlp/app/validator.py | 10 +- src/dsmlp/ext/kube.py | 17 +-- tests/app/test_tgpt_validator.py | 165 +++++++++++++++++++++++++++ tests/fakes.py | 12 ++ 5 files changed, 193 insertions(+), 19 deletions(-) diff --git a/src/dsmlp/app/tritongpt_validator.py b/src/dsmlp/app/tritongpt_validator.py index 2584e2f..47f88b2 100644 --- a/src/dsmlp/app/tritongpt_validator.py +++ b/src/dsmlp/app/tritongpt_validator.py @@ -22,10 +22,12 @@ def __init__(self, kube: KubeClient, logger: Logger) -> None: def validate_pod(self, request: Request): - permitted_uids = self.kube.get_tgpt_uids() + namespace = self.kube.get_namespace(request.namespace) + + permitted_uids = self.kube.get_tgpt_uids(namespace) requested_uid = request.object.spec.securityContext.runAsUser # if request.uid is not in kube.get_tgpt_uids # return validationfailure - if requested_uid not in permitted_uids: - raise ValidationFailure(f"TritonGPT Validator: user with {permitted_uids} attempted to run a pod as {requested_uid}. Pod denied.") + if str(requested_uid) not in permitted_uids: + raise ValidationFailure(f"TritonGPT Validator: user with access to UIDs {permitted_uids} attempted to run a pod as {requested_uid}. Pod denied.") diff --git a/src/dsmlp/app/validator.py b/src/dsmlp/app/validator.py index badbeed..f719f4d 100644 --- a/src/dsmlp/app/validator.py +++ b/src/dsmlp/app/validator.py @@ -55,13 +55,17 @@ def handle_request(self, request: Request): def validate_pod(self, request: Request): + ### if tgpt-validator == enabled + ### run special tritongpt validator that gets permitted UIDs from namespace instead of sicad try: - if(self.kube.get_tgpt_label(request.namespace) == "enabled"): + namespace = self.kube.get_namespace(request.namespace) + + if(self.kube.get_tgpt_label(namespace) == "enabled"): self.logger.info("Triton GPT Mode Activated. Only running TritonGPT Validator.") TritonGPTValidator(self.kube, self.logger).validate_pod(request) return - except: - self.logger.info("Failed to evaluate TGPT label logic. Falling back on regular validator components.") + except Exception as err: + self.logger.exception(err) for component_validator in self.component_validators: component_validator.validate_pod(request) diff --git a/src/dsmlp/ext/kube.py b/src/dsmlp/ext/kube.py index 44e88b5..55b918f 100644 --- a/src/dsmlp/ext/kube.py +++ b/src/dsmlp/ext/kube.py @@ -50,23 +50,14 @@ def get_gpus_in_namespace(self, name: str) -> int: return gpu_count - def get_tgpt_label(self, name: str) -> str: - api = self.get_policy_api() - v1namespace: V1Namespace = api.read_namespace(name=name) - metadata: V1ObjectMeta = v1namespace.metadata - - if metadata is not None and metadata.labels is not None and "tgpt-validator" in metadata.labels: - return metadata.labels["tgpt-validator"] + def get_tgpt_label(self, namespace) -> str: + return namespace.labels.get("tgt-validator","") # TODO: make arbitrary function of getting namespace labels. - def get_tgpt_uids(self, name: str) -> str: - api = self.get_policy_api() - v1namespace: V1Namespace = api.read_namespace(name=name) - metadata: V1ObjectMeta = v1namespace.metadata + def get_tgpt_uids(self, namespace) -> str: # should be comma delimited, i.e. 2000,100,2,20 - if metadata is not None and metadata.labels is not None and "permitted-uids" in metadata.labels: - return metadata.labels["permitted-uids"].split(',') + return namespace.labels.get("permitted-uids", "").split(',') # noinspection PyMethodMayBeStatic diff --git a/tests/app/test_tgpt_validator.py b/tests/app/test_tgpt_validator.py index e69de29..c733ae5 100644 --- a/tests/app/test_tgpt_validator.py +++ b/tests/app/test_tgpt_validator.py @@ -0,0 +1,165 @@ +import inspect +from operator import contains +from dsmlp.app.validator import Validator +from dsmlp.plugin.awsed import ListTeamsResponse, TeamJson, UserResponse +from dsmlp.plugin.kube import Namespace +from hamcrest import assert_that, contains_inanyorder, equal_to, has_item +from tests.fakes import FakeAwsedClient, FakeLogger, FakeKubeClient + + +class TestTGPTValidator: + def setup_method(self) -> None: + self.logger = FakeLogger() + self.awsed_client = FakeAwsedClient() + self.kube_client = FakeKubeClient() + + self.awsed_client.add_user( + 'user10', UserResponse(uid=30, enrollments=[])) + self.awsed_client.add_teams('user10', ListTeamsResponse( + teams=[TeamJson(gid=1000)] + )) + + self.kube_client.add_namespace('user10', Namespace( + name='user10', labels={'k8s-sync': 'true', 'tgpt-validator': 'enabled', 'permitted-uids': '30,3000'}, gpu_quota=10)) + + self.awsed_client.add_user( + 'user100', UserResponse(uid=10, enrollments=[])) + self.awsed_client.add_teams('user10', ListTeamsResponse( + teams=[TeamJson(gid=1000)] + )) + + self.kube_client.add_namespace('user100', Namespace( + name='user100', labels={'k8s-sync': 'true', 'tgpt-validator': 'disabled', 'permitted-uids': '10'}, gpu_quota=10)) + + def test_good_request(self): + self.when_validate( + { + "request": { + "uid": "705ab4f5-6393-11e8-b7cc-42010a800002", + "namespace": "user10", + "userInfo": { + "username": "system:kube-system" + }, + "object": { + "metadata": { + "labels": {} + }, + "spec": { + "containers": [{}], + "securityContext": {"runAsUser": 30}, + }, + } + } + } + ) + + assert_that(self.logger.messages, has_item( + f"INFO Allowed request username=system:kube-system namespace=user10 uid=705ab4f5-6393-11e8-b7cc-42010a800002")) + + def test_good_request_2(self): + self.when_validate( + { + "request": { + "uid": "705ab4f5-6393-11e8-b7cc-42010a800002", + "namespace": "user10", + "userInfo": { + "username": "system:kube-system" + }, + "object": { + "metadata": { + "labels": {} + }, + "spec": { + "containers": [{}], + "securityContext": {"runAsUser": 3000}, + }, + } + } + } + ) + + assert_that(self.logger.messages, has_item( + f"INFO Allowed request username=system:kube-system namespace=user10 uid=705ab4f5-6393-11e8-b7cc-42010a800002")) + + def test_bad_request(self): + self.when_validate( + { + "request": { + "uid": "705ab4f5-6393-11e8-b7cc-42010a800002", + "namespace": "user10", + "userInfo": { + "username": "system:kube-system" + }, + "object": { + "metadata": { + "labels": {} + }, + "spec": { + "containers": [{}], + "securityContext": {"runAsUser": 300}, + }, + } + } + } + ) + + assert_that(self.logger.messages, has_item( + f"EXCEPTION TritonGPT Validator: user with access to UIDs ['30', '3000'] attempted to run a pod as 300. Pod denied.")) + + def test_good_request_not_enabled_permitted_on(self): + self.when_validate( + { + "request": { + "uid": "705ab4f5-6393-11e8-b7cc-42010a800002", + "namespace": "user100", + "userInfo": { + "username": "system:kube-system" + }, + "object": { + "metadata": { + "labels": {} + }, + "spec": { + "containers": [{}], + "securityContext": {"runAsUser": 10}, + }, + } + } + } + ) + + assert_that(self.logger.messages, has_item( + f"INFO Allowed request username=system:kube-system namespace=user100 uid=705ab4f5-6393-11e8-b7cc-42010a800002")) + + #assert_that(self.logger.messages, has_item( + #"INFO Allowed request username=user10 namespace=user10 uid=705ab4f5-6393-11e8-b7cc-42010a800002")) + + # def test_gpu_quota_request(self): + # self.awsed_client.add_user_gpu_quota('user10', 10) + # self.awsed_client.get_user_gpu_quota('user10') + + # response = self.when_validate( + # { + # "request": { + # "uid": "705ab4f5-6393-11e8-b7cc-42010a800002", + # "namespace": "user10", + # "userInfo": { + # "username": "user10" + # }, + # "object": { + # "metadata": { + # "labels": {} + # }, + # "spec": { + # "containers": [{}] + # } + # } + # } + # } + # ) + + def when_validate(self, json): + validator = Validator(self.awsed_client, self.kube_client, self.logger) + response = validator.validate_request(json) + + return response diff --git a/tests/fakes.py b/tests/fakes.py index b2f2114..3751e6f 100644 --- a/tests/fakes.py +++ b/tests/fakes.py @@ -77,6 +77,18 @@ def add_namespace(self, name: str, namespace: Namespace): def set_existing_gpus(self, name: str, gpus: int): self.existing_gpus[name] = gpus + def get_tgpt_label(self, namespace) -> str: + try: + return namespace.labels.get("tgpt-validator", "") + except KeyError: + raise UnsuccessfulRequest() + + def get_tgpt_uids(self, namespace) -> str: + try: + return namespace.labels.get("permitted-uids").split(',') + except KeyError: + raise UnsuccessfulRequest() + class FakeLogger(Logger): def __init__(self) -> None: