diff --git a/koku/api/report/test/util/baker_recipes.py b/koku/api/report/test/util/baker_recipes.py index f384931b5e..781a4ed501 100644 --- a/koku/api/report/test/util/baker_recipes.py +++ b/koku/api/report/test/util/baker_recipes.py @@ -15,7 +15,6 @@ from api.report.test.util.constants import OCI_CONSTANTS from api.report.test.util.constants import OCP_CONSTANTS - fake = Faker() diff --git a/koku/koku/koku_test_runner.py b/koku/koku/koku_test_runner.py index 9fd59bbfac..ce8a60561e 100644 --- a/koku/koku/koku_test_runner.py +++ b/koku/koku/koku_test_runner.py @@ -13,7 +13,6 @@ from django.db import connections from django.test.runner import DiscoverRunner from django.test.utils import get_unique_databases_and_mirrors -from django_tenants.utils import tenant_context from api.models import Customer from api.models import Provider @@ -134,6 +133,7 @@ def setup_databases(verbosity, interactive, keepdb=False, debug_sql=False, paral # OCI bakery_data_loader.load_oci_data() + for account in [("10002", "org2222222", "2222222"), ("12345", "org3333333", "3333333")]: tenant = Tenant.objects.get_or_create(schema_name=account[1])[0] tenant.save() diff --git a/koku/masu/processor/ocp/ocp_report_db_cleaner.py b/koku/masu/processor/ocp/ocp_report_db_cleaner.py index 14f92b63b7..d16131972f 100644 --- a/koku/masu/processor/ocp/ocp_report_db_cleaner.py +++ b/koku/masu/processor/ocp/ocp_report_db_cleaner.py @@ -6,15 +6,12 @@ import logging from datetime import date -from django_tenants.utils import schema_context - from api.common import log_json from koku.database import cascade_delete from koku.database import execute_delete_sql from masu.database.ocp_report_db_accessor import OCPReportDBAccessor from reporting.models import EXPIRE_MANAGED_TABLES from reporting.models import PartitionedTable -from reporting.provider.ocp.models import OCPUsageReportPeriod from reporting.provider.ocp.models import UI_SUMMARY_TABLES LOG = logging.getLogger(__name__) @@ -65,29 +62,27 @@ def purge_expired_report_data(self, expired_date=None, provider_uuid=None, simul return self.purge_expired_report_data_by_date(expired_date, simulate=simulate) usage_period_objs = accessor.get_usage_period_query_by_provider(provider_uuid) - - with schema_context(self._schema): - for usage_period in usage_period_objs.all(): - removed_items.append( - {"usage_period_id": usage_period.id, "interval_start": str(usage_period.report_period_start)} - ) - all_report_periods.append(usage_period.id) - all_cluster_ids.add(usage_period.cluster_id) - all_period_starts.add(str(usage_period.report_period_start)) - - LOG.info( - log_json( - msg="deleting provider billing data", - schema=self._schema, - provider_uuid=provider_uuid, - report_periods=all_report_periods, - cluster_ids=all_cluster_ids, - period_starts=all_period_starts, - ) + for usage_period in usage_period_objs.all(): + removed_items.append( + {"usage_period_id": usage_period.id, "interval_start": str(usage_period.report_period_start)} ) + all_report_periods.append(usage_period.id) + all_cluster_ids.add(usage_period.cluster_id) + all_period_starts.add(str(usage_period.report_period_start)) + + LOG.info( + log_json( + msg="deleting provider billing data", + schema=self._schema, + provider_uuid=provider_uuid, + report_periods=all_report_periods, + cluster_ids=all_cluster_ids, + period_starts=all_period_starts, + ) + ) - if not simulate: - cascade_delete(usage_period_objs.query.model, usage_period_objs) + if not simulate: + cascade_delete(usage_period_objs.query.model, usage_period_objs) return removed_items @@ -107,7 +102,6 @@ def purge_expired_report_data_by_date(self, expired_date, simulate=False): ] table_names.extend(UI_SUMMARY_TABLES) - with schema_context(self._schema): # Iterate over the remainder as they could involve much larger amounts of data for usage_period in all_usage_periods: removed_items.append( @@ -135,11 +129,10 @@ def purge_expired_report_data_by_date(self, expired_date, simulate=False): LOG.info(log_json(msg="deleted table partitions", count=del_count, schema=self._schema)) # Remove all data related to the report period - del_count, _ = OCPUsageReportPeriod.objects.filter(id__in=all_report_periods).delete() + cascade_delete(all_usage_periods.query.model, all_usage_periods) LOG.info( log_json( msg="deleted ocp-usage-report-periods", - count=del_count, report_periods=all_report_periods, schema=self._schema, ) diff --git a/koku/masu/processor/orchestrator.py b/koku/masu/processor/orchestrator.py index 06e3461fe7..9d78edf26a 100644 --- a/koku/masu/processor/orchestrator.py +++ b/koku/masu/processor/orchestrator.py @@ -5,6 +5,7 @@ """Report Processing Orchestrator.""" import copy import logging +from collections import defaultdict from datetime import datetime from datetime import timedelta @@ -41,7 +42,6 @@ from subs.tasks import extract_subs_data_from_reports from subs.tasks import SUBS_EXTRACTION_QUEUE - LOG = logging.getLogger(__name__) @@ -565,17 +565,37 @@ def remove_expired_report_data(self, simulate=False): """ async_results = [] + schemas = defaultdict(set) for account in Provider.objects.get_accounts(): - LOG.info("Calling remove_expired_data with account: %s", account) - async_result = remove_expired_data.delay( - schema_name=account.get("schema_name"), provider=account.get("provider_type"), simulate=simulate - ) - LOG.info( - "Expired data removal queued - schema_name: %s, Task ID: %s", - account.get("schema_name"), - str(async_result), - ) - async_results.append({"customer": account.get("customer_name"), "async_id": str(async_result)}) + # create a dict of {schema: set(provider_types)} + schemas[account.get("schema_name")].add(account.get("provider_type")) + for schema, provider_types in schemas.items(): + provider_types = list(provider_types) + if Provider.PROVIDER_OCP in provider_types: + # move OCP to the end of the list because its ForeignKeys are complicated and these should be cleaned + # up after the cloud providers + provider_types.remove(Provider.PROVIDER_OCP) + provider_types.append(Provider.PROVIDER_OCP) + for provider_type in provider_types: + LOG.info( + log_json( + "remove_expired_report_data", + msg="calling remove_expired_data", + schema=schema, + provider_type=provider_type, + ) + ) + async_result = remove_expired_data.delay(schema_name=schema, provider=provider_type, simulate=simulate) + LOG.info( + log_json( + "remove_expired_report_data", + msg="expired data removal queued", + schema=schema, + provider_type=provider_type, + task_id=str(async_result), + ) + ) + async_results.append({"schema": schema, "provider_type": provider_type, "async_id": str(async_result)}) return async_results def remove_expired_trino_partitions(self, simulate=False): diff --git a/koku/masu/test/api/test_expired_data.py b/koku/masu/test/api/test_expired_data.py index 3b6f566b1e..e20749fda8 100644 --- a/koku/masu/test/api/test_expired_data.py +++ b/koku/masu/test/api/test_expired_data.py @@ -22,7 +22,9 @@ class ExpiredDataTest(TestCase): @patch.object(Orchestrator, "remove_expired_report_data") def test_get_expired_data(self, mock_orchestrator, _, mock_service): """Test the GET expired_data endpoint.""" - mock_response = [{"customer": "org1234567", "async_id": "f9eb2ce7-4564-4509-aecc-1200958c07cf"}] + mock_response = [ + {"schema": "org1234567", "provider_type": "OCP", "async_id": "f9eb2ce7-4564-4509-aecc-1200958c07cf"} + ] expected_key = "Async jobs for expired data removal (simulated)" mock_orchestrator.return_value = mock_response response = self.client.get(reverse("expired_data")) @@ -38,7 +40,9 @@ def test_get_expired_data(self, mock_orchestrator, _, mock_service): @patch.object(Orchestrator, "remove_expired_report_data") def test_del_expired_data(self, mock_orchestrator, mock_debug, _, mock_service): """Test the DELETE expired_data endpoint.""" - mock_response = [{"customer": "org1234567", "async_id": "f9eb2ce7-4564-4509-aecc-1200958c07cf"}] + mock_response = [ + {"schema": "org1234567", "provider_type": "OCP", "async_id": "f9eb2ce7-4564-4509-aecc-1200958c07cf"} + ] expected_key = "Async jobs for expired data removal" mock_orchestrator.return_value = mock_response @@ -54,7 +58,9 @@ def test_del_expired_data(self, mock_orchestrator, mock_debug, _, mock_service): @patch.object(Orchestrator, "remove_expired_trino_partitions") def test_get_expired_partitions(self, mock_orchestrator, _, mock_service): """Test the GET expired_trino_paritions endpoint.""" - mock_response = [{"customer": "org1234567", "async_id": "f9eb2ce7-4564-4509-aecc-1200958c07cf"}] + mock_response = [ + {"schema": "org1234567", "provider_type": "OCP", "async_id": "f9eb2ce7-4564-4509-aecc-1200958c07cf"} + ] expected_key = "Async jobs for expired paritions removal (simulated)" mock_orchestrator.return_value = mock_response response = self.client.get(reverse("expired_trino_partitions")) @@ -71,7 +77,9 @@ def test_get_expired_partitions(self, mock_orchestrator, _, mock_service): @patch.object(Orchestrator, "remove_expired_trino_partitions") def test_del_expired_partitions(self, mock_orchestrator, mock_debug, _, mock_service): """Test the DELETE expired_trino_partitions endpoint.""" - mock_response = [{"customer": "org1234567", "async_id": "f9eb2ce7-4564-4509-aecc-1200958c07cf"}] + mock_response = [ + {"schema": "org1234567", "provider_type": "OCP", "async_id": "f9eb2ce7-4564-4509-aecc-1200958c07cf"} + ] expected_key = "Async jobs for expired paritions removal" mock_orchestrator.return_value = mock_response diff --git a/koku/masu/test/processor/test_orchestrator.py b/koku/masu/test/processor/test_orchestrator.py index 497a3f426a..30e79dc153 100644 --- a/koku/masu/test/processor/test_orchestrator.py +++ b/koku/masu/test/processor/test_orchestrator.py @@ -112,25 +112,15 @@ def test_unleash_is_cloud_source_processing_disabled(self, mock_processing_check self.assertIn(expected_result, captured_logs.output[0]) @patch("masu.processor.worker_cache.CELERY_INSPECT") - @patch.object(ExpiredDataRemover, "remove") - @patch("masu.processor.orchestrator.remove_expired_data.apply_async", return_value=True) - def test_remove_expired_report_data(self, mock_task, mock_remover, mock_inspect): + @patch("masu.processor.orchestrator.remove_expired_data.delay", return_value=True) + def test_remove_expired_report_data(self, mock_task, mock_inspect): """Test removing expired report data.""" - expected_results = [{"account_payer_id": "999999999", "billing_period_start": "2018-06-24 15:47:33.052509"}] - mock_remover.return_value = expected_results - - expected = ( - "INFO:masu.processor.orchestrator:Expired data removal queued - schema_name: org1234567, Task ID: {}" - ) - # unset disabling all logging below CRITICAL from masu/__init__.py - logging.disable(logging.NOTSET) - with self.assertLogs("masu.processor.orchestrator", level="INFO") as logger: - orchestrator = Orchestrator() - results = orchestrator.remove_expired_report_data() - self.assertTrue(results) - self.assertEqual(len(results), 8) - async_id = results.pop().get("async_id") - self.assertIn(expected.format(async_id), logger.output) + orchestrator = Orchestrator() + results = orchestrator.remove_expired_report_data() + self.assertTrue(results) + self.assertEqual( + len(results), Provider.objects.order_by().distinct("type").count() + ) # number of distinct provider types @patch("masu.processor.worker_cache.CELERY_INSPECT") @patch.object(ExpiredDataRemover, "remove")