diff --git a/koku/koku/middleware.py b/koku/koku/middleware.py index a53d962d2f..cc5c9ffeaf 100644 --- a/koku/koku/middleware.py +++ b/koku/koku/middleware.py @@ -29,6 +29,7 @@ from django_tenants.middleware import TenantMainMiddleware from prometheus_client import Counter +from api.common import log_json from api.common import RH_IDENTITY_HEADER from api.common.pagination import EmptyResultsSetPagination from api.iam.models import Customer @@ -247,10 +248,11 @@ class IdentityHeaderMiddleware(MiddlewareMixin): def create_customer(account, org_id, request_method): """Create a customer. Args: - account (str): The account identifier - org_id (str): The org_id identifier + account (str): The account identifier. + org_id (str): The org_id identifier. + request_method (str): The HTTP request method. Returns: - (Customer) The created customer + Customer : The created or retrieved customer. """ try: with transaction.atomic(): @@ -260,7 +262,16 @@ def create_customer(account, org_id, request_method): customer.save() UNIQUE_ACCOUNT_COUNTER.inc() LOG.info("Created new customer from account_id %s and org_id %s.", account, org_id) - except IntegrityError: + + except IntegrityError as err: + LOG.warning( + log_json( + msg="IntegrityError when creating customer. Attempting to fetch existing record", + account=account, + org_id=org_id, + exc_info=err, + ) + ) customer = Customer.objects.filter(org_id=org_id).get() return customer @@ -426,7 +437,7 @@ def process_response(self, request, response): class RequestTimingMiddleware(MiddlewareMixin): """A class to add total time taken to a request/response.""" - def process_request(self, request): # noqa: C901 + def process_request(self, request): """Process request to add start time. Args: request (object): The request object diff --git a/koku/koku/test_middleware.py b/koku/koku/test_middleware.py index f62512f780..65ee4b6909 100644 --- a/koku/koku/test_middleware.py +++ b/koku/koku/test_middleware.py @@ -15,6 +15,7 @@ from cachetools import TTLCache from django.core.cache import caches from django.core.exceptions import PermissionDenied +from django.db.utils import IntegrityError from django.db.utils import OperationalError from django.http import JsonResponse from django.test.utils import modify_settings @@ -482,6 +483,33 @@ def test_process_service_account_identity(self): middleware = IdentityHeaderMiddleware(self.mock_get_response) middleware.process_request(mock_request) + @patch("api.iam.models.Customer.save") + def test_create_customer(self, mock_save): + """Test creating a customer.""" + + mock_save.return_value = None + customer = IdentityHeaderMiddleware.create_customer("test_account", "test_org", "POST") + + self.assertIsNotNone(customer) + self.assertEqual(customer.account_id, "test_account") + mock_save.assert_called_once() + + @patch("api.iam.models.Customer.objects.filter") + @patch("api.iam.models.Customer.save", side_effect=IntegrityError) + def test_create_customer_integrity_error_existing_customer(self, mock_save, mock_filter): + """Test fetching an existing customer when an IntegrityError occurs.""" + + mock_query_set = MagicMock() + mock_filter.return_value = mock_query_set + mock_query_set.get.return_value = MagicMock(account_id="test_account", org_id="test_org") + + customer = IdentityHeaderMiddleware.create_customer("test_account", "test_org", "POST") + + self.assertIsNotNone(customer) + mock_save.assert_called_once() + self.assertEqual(customer.org_id, "test_org") + mock_filter.assert_called_once_with(org_id="test_org") + class RequestTimingMiddlewareTest(IamTestCase): """Tests against the koku tenant middleware.""" diff --git a/koku/masu/management/commands/migrate_trino_tables.py b/koku/masu/management/commands/migrate_trino_tables.py index 3fd59bfe5e..9905d3588c 100644 --- a/koku/masu/management/commands/migrate_trino_tables.py +++ b/koku/masu/management/commands/migrate_trino_tables.py @@ -343,7 +343,7 @@ def run_trino_sql(sql, schema=None) -> list[t.Optional[list[int]]]: return cur.fetchall() except TrinoExternalError as err: if err.error_name == "HIVE_METASTORE_ERROR" and n < (retries): - LOG.warn( + LOG.warning( f"{err.message}. Attempt number {attempt} of {retries} failed. " f"Trying {remaining_retries} more time{'s' if remaining_retries > 1 else ''} " f"after waiting {wait:.2f}s."