From bdd37735a8cea295d73693143c2010453e587a35 Mon Sep 17 00:00:00 2001 From: trn024 Date: Thu, 29 Aug 2024 14:53:32 -0700 Subject: [PATCH] The key should be nvidia.com/gpu instead of GPU when fetching awsed gpu quota --- src/dsmlp/ext/awsed.py | 3 ++- tests/app/test_gpu_validator.py | 6 +++--- tests/fakes.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/dsmlp/ext/awsed.py b/src/dsmlp/ext/awsed.py index b5daf5f..788f7a9 100644 --- a/src/dsmlp/ext/awsed.py +++ b/src/dsmlp/ext/awsed.py @@ -8,6 +8,7 @@ import awsed.client import awsed.types import logging +from dsmlp.plugin.logger import Logger # added logging to check if API has an error getting GPU quota logging.basicConfig(level=logging.ERROR, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') @@ -38,7 +39,7 @@ def get_user_gpu_quota(self, username: str) -> UserQuotaResponse: usrGpuQuota = self.client.get_user_quota(username) if not usrGpuQuota: return None - gpu_quota = usrGpuQuota.get("gpu", 0) + gpu_quota = usrGpuQuota.get("nvidia.com/gpu", 0) quota = Quota(user=username, resources=gpu_quota) return UserQuotaResponse(quota=quota) except Exception as e: diff --git a/tests/app/test_gpu_validator.py b/tests/app/test_gpu_validator.py index df09044..f5754e8 100644 --- a/tests/app/test_gpu_validator.py +++ b/tests/app/test_gpu_validator.py @@ -92,7 +92,7 @@ def try_validate(self, json, expected: bool, message: str = None): # Test correct response for get_user_gpu_quota method def test_awsed_gpu_quota_correct_response(self): - self.awsed_client.assign_user_gpu_quota('user11', {"gpu": 5}) + self.awsed_client.assign_user_gpu_quota('user11', {"nvidia.com/gpu": 5}) user_gpu_quota = self.awsed_client.get_user_gpu_quota('user11') assert_that(user_gpu_quota, equal_to(5)) @@ -123,7 +123,7 @@ def test_gpu_quota_client_priority(self): self.kube_client.set_existing_gpus('user11', 3) # add awsed quota - self.awsed_client.assign_user_gpu_quota('user11', {"gpu": 6}) + 
self.awsed_client.assign_user_gpu_quota('user11', {"nvidia.com/gpu": 6}) self.try_validate( gen_request(gpu_req=6, username='user11'), expected=False, message="GPU quota exceeded. Wanted 6 but with 3 already in use, the quota of 6 would be exceeded." ) @@ -133,7 +133,7 @@ def test_gpu_quota_client_priority2(self): self.kube_client.add_namespace('user11', Namespace( name='user11', labels={'k8s-sync': 'true'}, gpu_quota=12)) # add awsed quota - self.awsed_client.assign_user_gpu_quota('user11', {"gpu": 18}) + self.awsed_client.assign_user_gpu_quota('user11', {"nvidia.com/gpu": 18}) # set existing gpu = kube client quota self.kube_client.set_existing_gpus('user11', 12) diff --git a/tests/fakes.py b/tests/fakes.py index b2f2114..d977550 100644 --- a/tests/fakes.py +++ b/tests/fakes.py @@ -34,7 +34,7 @@ def describe_user(self, username: str) -> UserResponse: def get_user_gpu_quota(self, username: str) -> int: try: user_quota_response = self.user_quota[username] - return user_quota_response.quota.resources.get("gpu", 0) + return user_quota_response.quota.resources.get("nvidia.com/gpu", 0) except KeyError: return 0