Skip to content

Commit

Permalink
Fixed issue with getting container gpus
Browse files Browse the repository at this point in the history
  • Loading branch information
shouhanzen committed Jan 26, 2024
1 parent c72acfb commit f555e09
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 5 deletions.
9 changes: 5 additions & 4 deletions src/dsmlp/ext/kube.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ def get_gpus_in_namespace(self, name: str) -> int:

gpu_count = 0
for pod in pods.items:
try:
gpu_count += int(pod.spec.containers.resources.requests['GPU_LABEL'])
except KeyError:
pass
for container in pod.spec.containers:
try:
gpu_count += int(container.resources.requests[GPU_LABEL])
except (KeyError, TypeError):
pass

return gpu_count

Expand Down
34 changes: 33 additions & 1 deletion tests/app/test_gpu_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dsmlp.plugin.kube import Namespace
from hamcrest import assert_that, contains_inanyorder, equal_to, has_item
from tests.fakes import FakeAwsedClient, FakeLogger, FakeKubeClient

from dsmlp.ext.kube import DefaultKubeClient

class TestValidator:
def setup_method(self) -> None:
Expand Down Expand Up @@ -185,6 +185,38 @@ def test_low_priority(self):
"message": "Allowed"
}}}))

def test_collect_gpus(self):
real_kube_client = DefaultKubeClient()

from kubernetes.client import V1PodList, V1Pod, V1PodSpec, V1Container, V1ResourceRequirements

class FakeInternalClient:
def read_namespace(self, name: str) -> Namespace:
return "namespace"
def list_namespaced_pod(self, namespace: str) -> int:

return V1PodList(
items=[
V1Pod(
spec=V1PodSpec(
containers=[
V1Container(
name="container1",
resources=V1ResourceRequirements(
)
)
]
)
)
]
)

def get_policy_api():
return FakeInternalClient()

real_kube_client.get_policy_api = get_policy_api
real_kube_client.get_gpus_in_namespace('user10')

def when_validate(self, json):
validator = Validator(self.awsed_client, self.kube_client, self.logger)
response = validator.validate_request(json)
Expand Down

0 comments on commit f555e09

Please sign in to comment.