From f555e09e034f3ff0c75ac990a3707209b861187c Mon Sep 17 00:00:00 2001 From: D0rkKnight Date: Thu, 25 Jan 2024 17:52:24 -0800 Subject: [PATCH] Fixed issue with getting container gpus --- src/dsmlp/ext/kube.py | 9 +++++---- tests/app/test_gpu_validator.py | 34 ++++++++++++++++++++++++++++++++- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/src/dsmlp/ext/kube.py b/src/dsmlp/ext/kube.py index c1cbd08..a3a8e28 100644 --- a/src/dsmlp/ext/kube.py +++ b/src/dsmlp/ext/kube.py @@ -35,10 +35,11 @@ def get_gpus_in_namespace(self, name: str) -> int: gpu_count = 0 for pod in pods.items: - try: - gpu_count += int(pod.spec.containers.resources.requests['GPU_LABEL']) - except KeyError: - pass + for container in pod.spec.containers: + try: + gpu_count += int(container.resources.requests[GPU_LABEL]) + except (KeyError, TypeError): + pass return gpu_count diff --git a/tests/app/test_gpu_validator.py b/tests/app/test_gpu_validator.py index 9c36f0b..ebe7e03 100644 --- a/tests/app/test_gpu_validator.py +++ b/tests/app/test_gpu_validator.py @@ -5,7 +5,7 @@ from dsmlp.plugin.kube import Namespace from hamcrest import assert_that, contains_inanyorder, equal_to, has_item from tests.fakes import FakeAwsedClient, FakeLogger, FakeKubeClient - +from dsmlp.ext.kube import DefaultKubeClient class TestValidator: def setup_method(self) -> None: @@ -185,6 +185,38 @@ def test_low_priority(self): "message": "Allowed" }}})) + def test_collect_gpus(self): + real_kube_client = DefaultKubeClient() + + from kubernetes.client import V1PodList, V1Pod, V1PodSpec, V1Container, V1ResourceRequirements + + class FakeInternalClient: + def read_namespace(self, name: str) -> Namespace: + return "namespace" + def list_namespaced_pod(self, namespace: str) -> int: + + return V1PodList( + items=[ + V1Pod( + spec=V1PodSpec( + containers=[ + V1Container( + name="container1", + resources=V1ResourceRequirements( + ) + ) + ] + ) + ) + ] + ) + + def get_policy_api(): + return FakeInternalClient() + + real_kube_client.get_policy_api = get_policy_api + real_kube_client.get_gpus_in_namespace('user10') + def when_validate(self, json): validator = Validator(self.awsed_client, self.kube_client, self.logger) response = validator.validate_request(json)