Skip to content

Commit

Permalink
[COST-5851] one expiration task per provider_type (#5432)
Browse files Browse the repository at this point in the history
  • Loading branch information
maskarb authored Jan 28, 2025
1 parent 7a51af4 commit 162a221
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 62 deletions.
1 change: 0 additions & 1 deletion koku/api/report/test/util/baker_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from api.report.test.util.constants import OCI_CONSTANTS
from api.report.test.util.constants import OCP_CONSTANTS


fake = Faker()


Expand Down
2 changes: 1 addition & 1 deletion koku/koku/koku_test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from django.db import connections
from django.test.runner import DiscoverRunner
from django.test.utils import get_unique_databases_and_mirrors
from django_tenants.utils import tenant_context

from api.models import Customer
from api.models import Provider
Expand Down Expand Up @@ -134,6 +133,7 @@ def setup_databases(verbosity, interactive, keepdb=False, debug_sql=False, paral

# OCI
bakery_data_loader.load_oci_data()

for account in [("10002", "org2222222", "2222222"), ("12345", "org3333333", "3333333")]:
tenant = Tenant.objects.get_or_create(schema_name=account[1])[0]
tenant.save()
Expand Down
47 changes: 20 additions & 27 deletions koku/masu/processor/ocp/ocp_report_db_cleaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,12 @@
import logging
from datetime import date

from django_tenants.utils import schema_context

from api.common import log_json
from koku.database import cascade_delete
from koku.database import execute_delete_sql
from masu.database.ocp_report_db_accessor import OCPReportDBAccessor
from reporting.models import EXPIRE_MANAGED_TABLES
from reporting.models import PartitionedTable
from reporting.provider.ocp.models import OCPUsageReportPeriod
from reporting.provider.ocp.models import UI_SUMMARY_TABLES

LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -65,29 +62,27 @@ def purge_expired_report_data(self, expired_date=None, provider_uuid=None, simul
return self.purge_expired_report_data_by_date(expired_date, simulate=simulate)

usage_period_objs = accessor.get_usage_period_query_by_provider(provider_uuid)

with schema_context(self._schema):
for usage_period in usage_period_objs.all():
removed_items.append(
{"usage_period_id": usage_period.id, "interval_start": str(usage_period.report_period_start)}
)
all_report_periods.append(usage_period.id)
all_cluster_ids.add(usage_period.cluster_id)
all_period_starts.add(str(usage_period.report_period_start))

LOG.info(
log_json(
msg="deleting provider billing data",
schema=self._schema,
provider_uuid=provider_uuid,
report_periods=all_report_periods,
cluster_ids=all_cluster_ids,
period_starts=all_period_starts,
)
for usage_period in usage_period_objs.all():
removed_items.append(
{"usage_period_id": usage_period.id, "interval_start": str(usage_period.report_period_start)}
)
all_report_periods.append(usage_period.id)
all_cluster_ids.add(usage_period.cluster_id)
all_period_starts.add(str(usage_period.report_period_start))

LOG.info(
log_json(
msg="deleting provider billing data",
schema=self._schema,
provider_uuid=provider_uuid,
report_periods=all_report_periods,
cluster_ids=all_cluster_ids,
period_starts=all_period_starts,
)
)

if not simulate:
cascade_delete(usage_period_objs.query.model, usage_period_objs)
if not simulate:
cascade_delete(usage_period_objs.query.model, usage_period_objs)

return removed_items

Expand All @@ -107,7 +102,6 @@ def purge_expired_report_data_by_date(self, expired_date, simulate=False):
]
table_names.extend(UI_SUMMARY_TABLES)

with schema_context(self._schema):
# Iterate over the remainder as they could involve much larger amounts of data
for usage_period in all_usage_periods:
removed_items.append(
Expand Down Expand Up @@ -135,11 +129,10 @@ def purge_expired_report_data_by_date(self, expired_date, simulate=False):
LOG.info(log_json(msg="deleted table partitions", count=del_count, schema=self._schema))

# Remove all data related to the report period
del_count, _ = OCPUsageReportPeriod.objects.filter(id__in=all_report_periods).delete()
cascade_delete(all_usage_periods.query.model, all_usage_periods)
LOG.info(
log_json(
msg="deleted ocp-usage-report-periods",
count=del_count,
report_periods=all_report_periods,
schema=self._schema,
)
Expand Down
42 changes: 31 additions & 11 deletions koku/masu/processor/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""Report Processing Orchestrator."""
import copy
import logging
from collections import defaultdict
from datetime import datetime
from datetime import timedelta

Expand Down Expand Up @@ -41,7 +42,6 @@
from subs.tasks import extract_subs_data_from_reports
from subs.tasks import SUBS_EXTRACTION_QUEUE


LOG = logging.getLogger(__name__)


Expand Down Expand Up @@ -565,17 +565,37 @@ def remove_expired_report_data(self, simulate=False):
"""
async_results = []
schemas = defaultdict(set)
for account in Provider.objects.get_accounts():
LOG.info("Calling remove_expired_data with account: %s", account)
async_result = remove_expired_data.delay(
schema_name=account.get("schema_name"), provider=account.get("provider_type"), simulate=simulate
)
LOG.info(
"Expired data removal queued - schema_name: %s, Task ID: %s",
account.get("schema_name"),
str(async_result),
)
async_results.append({"customer": account.get("customer_name"), "async_id": str(async_result)})
# create a dict of {schema: set(provider_types)}
schemas[account.get("schema_name")].add(account.get("provider_type"))
for schema, provider_types in schemas.items():
provider_types = list(provider_types)
if Provider.PROVIDER_OCP in provider_types:
# move OCP to the end of the list: its ForeignKey relationships are complicated, so OCP data should be
# cleaned up after the cloud providers
provider_types.remove(Provider.PROVIDER_OCP)
provider_types.append(Provider.PROVIDER_OCP)
for provider_type in provider_types:
LOG.info(
log_json(
"remove_expired_report_data",
msg="calling remove_expired_data",
schema=schema,
provider_type=provider_type,
)
)
async_result = remove_expired_data.delay(schema_name=schema, provider=provider_type, simulate=simulate)
LOG.info(
log_json(
"remove_expired_report_data",
msg="expired data removal queued",
schema=schema,
provider_type=provider_type,
task_id=str(async_result),
)
)
async_results.append({"schema": schema, "provider_type": provider_type, "async_id": str(async_result)})
return async_results

def remove_expired_trino_partitions(self, simulate=False):
Expand Down
16 changes: 12 additions & 4 deletions koku/masu/test/api/test_expired_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ class ExpiredDataTest(TestCase):
@patch.object(Orchestrator, "remove_expired_report_data")
def test_get_expired_data(self, mock_orchestrator, _, mock_service):
"""Test the GET expired_data endpoint."""
mock_response = [{"customer": "org1234567", "async_id": "f9eb2ce7-4564-4509-aecc-1200958c07cf"}]
mock_response = [
{"schema": "org1234567", "provider_type": "OCP", "async_id": "f9eb2ce7-4564-4509-aecc-1200958c07cf"}
]
expected_key = "Async jobs for expired data removal (simulated)"
mock_orchestrator.return_value = mock_response
response = self.client.get(reverse("expired_data"))
Expand All @@ -38,7 +40,9 @@ def test_get_expired_data(self, mock_orchestrator, _, mock_service):
@patch.object(Orchestrator, "remove_expired_report_data")
def test_del_expired_data(self, mock_orchestrator, mock_debug, _, mock_service):
"""Test the DELETE expired_data endpoint."""
mock_response = [{"customer": "org1234567", "async_id": "f9eb2ce7-4564-4509-aecc-1200958c07cf"}]
mock_response = [
{"schema": "org1234567", "provider_type": "OCP", "async_id": "f9eb2ce7-4564-4509-aecc-1200958c07cf"}
]
expected_key = "Async jobs for expired data removal"
mock_orchestrator.return_value = mock_response

Expand All @@ -54,7 +58,9 @@ def test_del_expired_data(self, mock_orchestrator, mock_debug, _, mock_service):
@patch.object(Orchestrator, "remove_expired_trino_partitions")
def test_get_expired_partitions(self, mock_orchestrator, _, mock_service):
"""Test the GET expired_trino_partitions endpoint."""
mock_response = [{"customer": "org1234567", "async_id": "f9eb2ce7-4564-4509-aecc-1200958c07cf"}]
mock_response = [
{"schema": "org1234567", "provider_type": "OCP", "async_id": "f9eb2ce7-4564-4509-aecc-1200958c07cf"}
]
expected_key = "Async jobs for expired paritions removal (simulated)"
mock_orchestrator.return_value = mock_response
response = self.client.get(reverse("expired_trino_partitions"))
Expand All @@ -71,7 +77,9 @@ def test_get_expired_partitions(self, mock_orchestrator, _, mock_service):
@patch.object(Orchestrator, "remove_expired_trino_partitions")
def test_del_expired_partitions(self, mock_orchestrator, mock_debug, _, mock_service):
"""Test the DELETE expired_trino_partitions endpoint."""
mock_response = [{"customer": "org1234567", "async_id": "f9eb2ce7-4564-4509-aecc-1200958c07cf"}]
mock_response = [
{"schema": "org1234567", "provider_type": "OCP", "async_id": "f9eb2ce7-4564-4509-aecc-1200958c07cf"}
]
expected_key = "Async jobs for expired paritions removal"
mock_orchestrator.return_value = mock_response

Expand Down
26 changes: 8 additions & 18 deletions koku/masu/test/processor/test_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,25 +112,15 @@ def test_unleash_is_cloud_source_processing_disabled(self, mock_processing_check
self.assertIn(expected_result, captured_logs.output[0])

@patch("masu.processor.worker_cache.CELERY_INSPECT")
@patch.object(ExpiredDataRemover, "remove")
@patch("masu.processor.orchestrator.remove_expired_data.apply_async", return_value=True)
def test_remove_expired_report_data(self, mock_task, mock_remover, mock_inspect):
@patch("masu.processor.orchestrator.remove_expired_data.delay", return_value=True)
def test_remove_expired_report_data(self, mock_task, mock_inspect):
"""Test removing expired report data."""
expected_results = [{"account_payer_id": "999999999", "billing_period_start": "2018-06-24 15:47:33.052509"}]
mock_remover.return_value = expected_results

expected = (
"INFO:masu.processor.orchestrator:Expired data removal queued - schema_name: org1234567, Task ID: {}"
)
# unset disabling all logging below CRITICAL from masu/__init__.py
logging.disable(logging.NOTSET)
with self.assertLogs("masu.processor.orchestrator", level="INFO") as logger:
orchestrator = Orchestrator()
results = orchestrator.remove_expired_report_data()
self.assertTrue(results)
self.assertEqual(len(results), 8)
async_id = results.pop().get("async_id")
self.assertIn(expected.format(async_id), logger.output)
orchestrator = Orchestrator()
results = orchestrator.remove_expired_report_data()
self.assertTrue(results)
self.assertEqual(
len(results), Provider.objects.order_by().distinct("type").count()
) # number of distinct provider types

@patch("masu.processor.worker_cache.CELERY_INSPECT")
@patch.object(ExpiredDataRemover, "remove")
Expand Down

0 comments on commit 162a221

Please sign in to comment.