diff --git a/koku/masu/database/ocp_report_db_accessor.py b/koku/masu/database/ocp_report_db_accessor.py
index 5e730a052d..e10d3e9700 100644
--- a/koku/masu/database/ocp_report_db_accessor.py
+++ b/koku/masu/database/ocp_report_db_accessor.py
@@ -17,7 +17,6 @@
 from django.db.models import F
 from django.db.models import Value
 from django.db.models.functions import Coalesce
-from django_tenants.utils import schema_context
 from trino.exceptions import TrinoExternalError

 from api.common import log_json
@@ -25,7 +24,6 @@
 from api.provider.models import Provider
 from koku.database import SQLScriptAtomicExecutorMixin
 from masu.config import Config
-from masu.database import AWS_CUR_TABLE_MAP
 from masu.database import OCP_REPORT_TABLE_MAP
 from masu.database.report_db_accessor_base import ReportDBAccessorBase
 from masu.util.common import filter_dictionary
@@ -74,46 +72,24 @@ def __init__(self, schema):
         super().__init__(schema)
         self._datetime_format = Config.OCP_DATETIME_STR_FORMAT
         self._table_map = OCP_REPORT_TABLE_MAP
-        self._aws_table_map = AWS_CUR_TABLE_MAP

     @property
     def line_item_daily_summary_table(self):
         return OCPUsageLineItemDailySummary

-    def get_current_usage_period(self, provider_uuid):
-        """Get the most recent usage report period object."""
-        with schema_context(self.schema):
-            return (
-                OCPUsageReportPeriod.objects.filter(provider_id=provider_uuid).order_by("-report_period_start").first()
-            )
-
-    def get_usage_period_by_dates_and_cluster(self, start_date, end_date, cluster_id):
-        """Return all report period entries for the specified start date."""
-        table_name = self._table_map["report_period"]
-        with schema_context(self.schema):
-            return (
-                self._get_db_obj_query(table_name)
-                .filter(report_period_start=start_date, report_period_end=end_date, cluster_id=cluster_id)
-                .first()
-            )
-
     def get_usage_period_query_by_provider(self, provider_uuid):
         """Return all report periods for the specified provider."""
-        table_name = self._table_map["report_period"]
-        with schema_context(self.schema):
-            return self._get_db_obj_query(table_name).filter(provider_id=provider_uuid)
+        return OCPUsageReportPeriod.objects.filter(provider_id=provider_uuid)

     def report_periods_for_provider_uuid(self, provider_uuid, start_date=None):
         """Return all report periods for provider_uuid on date."""
         report_periods = self.get_usage_period_query_by_provider(provider_uuid)
-        with schema_context(self.schema):
-            if start_date:
-                if isinstance(start_date, str):
-                    start_date = parse(start_date)
-                report_date = start_date.replace(day=1)
-                report_periods = report_periods.filter(report_period_start=report_date).first()
-
-        return report_periods
+        if start_date:
+            if isinstance(start_date, str):
+                start_date = parse(start_date)
+            report_date = start_date.replace(day=1)
+            report_periods = report_periods.filter(report_period_start=report_date).first()
+        return report_periods

     def populate_ui_summary_tables(self, start_date, end_date, source_uuid, tables=UI_SUMMARY_TABLES):
         """Populate our UI summary tables (formerly materialized views)."""
@@ -438,45 +414,16 @@ def populate_volume_label_summary_table(self, report_period_ids, start_date, end

     def populate_markup_cost(self, markup, start_date, end_date, cluster_id):
         """Set markup cost for OCP including infrastructure cost markup."""
-        with schema_context(self.schema):
-            OCPUsageLineItemDailySummary.objects.filter(
-                cluster_id=cluster_id, usage_start__gte=start_date, usage_start__lte=end_date
-            ).update(
-                infrastructure_markup_cost=(
-                    (Coalesce(F("infrastructure_raw_cost"), Value(0, output_field=DecimalField()))) * markup
-                ),
-                infrastructure_project_markup_cost=(
-                    (Coalesce(F("infrastructure_project_raw_cost"), Value(0, output_field=DecimalField()))) * markup
-                ),
-            )
-
-    def get_distinct_nodes(self, start_date, end_date, cluster_id):
-        """Return a list of nodes for a cluster between given dates."""
-        with schema_context(self.schema):
-            unique_nodes = (
-                OCPUsageLineItemDailySummary.objects.filter(
-                    usage_start__gte=start_date, usage_start__lt=end_date, cluster_id=cluster_id, node__isnull=False
-                )
-                .values_list("node")
-                .distinct()
-            )
-            return [node[0] for node in unique_nodes]
-
-    def get_distinct_pvcs(self, start_date, end_date, cluster_id):
-        """Return a list of tuples of (PVC, node) for a cluster between given dates."""
-        with schema_context(self.schema):
-            unique_pvcs = (
-                OCPUsageLineItemDailySummary.objects.filter(
-                    usage_start__gte=start_date,
-                    usage_start__lt=end_date,
-                    cluster_id=cluster_id,
-                    persistentvolumeclaim__isnull=False,
-                    namespace__isnull=False,
-                )
-                .values_list("persistentvolumeclaim", "node", "namespace")
-                .distinct()
-            )
-            return [(pvc[0], pvc[1], pvc[2]) for pvc in unique_pvcs]
+        OCPUsageLineItemDailySummary.objects.filter(
+            cluster_id=cluster_id, usage_start__gte=start_date, usage_start__lte=end_date
+        ).update(
+            infrastructure_markup_cost=(
+                (Coalesce(F("infrastructure_raw_cost"), Value(0, output_field=DecimalField()))) * markup
+            ),
+            infrastructure_project_markup_cost=(
+                (Coalesce(F("infrastructure_project_raw_cost"), Value(0, output_field=DecimalField()))) * markup
+            ),
+        )

     def populate_platform_and_worker_distributed_cost_sql(
         self, start_date, end_date, provider_uuid, distribution_info
@@ -499,9 +446,8 @@ def populate_platform_and_worker_distributed_cost_sql(
             context = {"schema": self.schema, "provider_uuid": provider_uuid, "start_date": start_date}
             LOG.info(log_json(msg=msg, context=context))
             return
-        with schema_context(self.schema):
-            report_period_id = report_period.id
+        report_period_id = report_period.id

         distribute_mapping = {
             "platform_cost": {
                 "sql_file": "distribute_platform_cost.sql",
@@ -569,9 +515,7 @@ def populate_monthly_cost_sql(self, cost_type, rate_type, rate, start_date, end_
                 )
             )
             return
-        with schema_context(self.schema):
-            report_period_id = report_period.id
-
+        report_period_id = report_period.id
         if not rate:
             LOG.info(log_json(msg="removing monthly costs", context=ctx))
             self.delete_line_item_daily_summary_entries_for_date_range_raw(
@@ -631,8 +575,7 @@ def populate_monthly_tag_cost_sql(  # noqa: C901
                 )
             )
             return
-        with schema_context(self.schema):
-            report_period_id = report_period.id
+        report_period_id = report_period.id

         cpu_case, memory_case, volume_case = case_dict.get("cost")
         labels = case_dict.get("labels")
@@ -721,8 +664,7 @@ def populate_usage_costs(self, rate_type, rates, start_date, end_date, provider_
                 )
             )
             return
-        with schema_context(self.schema):
-            report_period_id = report_period.id
+        report_period_id = report_period.id

         if not rates:
             LOG.info(log_json(msg="removing usage costs", context=ctx))
@@ -947,76 +889,72 @@ def populate_openshift_cluster_information_tables(self, provider, cluster_id, cl

     def populate_cluster_table(self, provider, cluster_id, cluster_alias):
         """Get or create an entry in the OCP cluster table."""
-        with schema_context(self.schema):
-            LOG.info(log_json(msg="fetching entry in reporting_ocp_cluster", provider_uuid=provider.uuid))
-            clusters = OCPCluster.objects.filter(provider_id=provider.uuid)
-            if clusters.count() > 1:
-                clusters_to_delete = clusters.exclude(cluster_alias=cluster_alias)
-                LOG.info(
-                    log_json(
-                        msg="attempting to delete duplicate entries in reporting_ocp_cluster",
-                        provider_uuid=provider.uuid,
-                    )
-                )
-                clusters_to_delete.delete()
-            cluster = clusters.first()
-            msg = "fetched entry in reporting_ocp_cluster"
-            if not cluster:
-                cluster, created = OCPCluster.objects.get_or_create(
-                    cluster_id=cluster_id, cluster_alias=cluster_alias, provider_id=provider.uuid
-                )
-                msg = f"created entry in reporting_ocp_clusters: {created}"
-
-            # if the cluster entry already exists and cluster alias does not match, update the cluster alias
-            elif cluster.cluster_alias != cluster_alias:
-                cluster.cluster_alias = cluster_alias
-                cluster.save()
-                msg = "updated cluster entry with new cluster alias in reporting_ocp_clusters"
-
+        LOG.info(log_json(msg="fetching entry in reporting_ocp_cluster", provider_uuid=provider.uuid))
+        clusters = OCPCluster.objects.filter(provider_id=provider.uuid)
+        if clusters.count() > 1:
+            clusters_to_delete = clusters.exclude(cluster_alias=cluster_alias)
             LOG.info(
                 log_json(
-                    msg=msg,
-                    cluster_id=cluster_id,
-                    cluster_alias=cluster_alias,
+                    msg="attempting to delete duplicate entries in reporting_ocp_cluster",
                     provider_uuid=provider.uuid,
                 )
             )
+            clusters_to_delete.delete()
+        cluster = clusters.first()
+        msg = "fetched entry in reporting_ocp_cluster"
+        if not cluster:
+            cluster, created = OCPCluster.objects.get_or_create(
+                cluster_id=cluster_id, cluster_alias=cluster_alias, provider_id=provider.uuid
+            )
+            msg = f"created entry in reporting_ocp_clusters: {created}"
+
+        # if the cluster entry already exists and cluster alias does not match, update the cluster alias
+        elif cluster.cluster_alias != cluster_alias:
+            cluster.cluster_alias = cluster_alias
+            cluster.save()
+            msg = "updated cluster entry with new cluster alias in reporting_ocp_clusters"
+
+        LOG.info(
+            log_json(
+                msg=msg,
+                cluster_id=cluster_id,
+                cluster_alias=cluster_alias,
+                provider_uuid=provider.uuid,
+            )
+        )
         return cluster

     def populate_node_table(self, cluster, nodes):
         """Get or create an entry in the OCP node table."""
         LOG.info(log_json(msg="populating reporting_ocp_nodes table", schema=self.schema, cluster=cluster))
-        with schema_context(self.schema):
-            for node in nodes:
-                tmp_node = OCPNode.objects.filter(
-                    node=node[0], resource_id=node[1], node_capacity_cpu_cores=node[2], cluster=cluster
-                ).first()
-                if not tmp_node:
-                    OCPNode.objects.create(
-                        node=node[0],
-                        resource_id=node[1],
-                        node_capacity_cpu_cores=node[2],
-                        node_role=node[3],
-                        cluster=cluster,
-                    )
-                # if the node entry already exists but does not have a role assigned, update the node role
-                elif not tmp_node.node_role:
-                    tmp_node.node_role = node[3]
-                    tmp_node.save()
+        for node in nodes:
+            tmp_node = OCPNode.objects.filter(
+                node=node[0], resource_id=node[1], node_capacity_cpu_cores=node[2], cluster=cluster
+            ).first()
+            if not tmp_node:
+                OCPNode.objects.create(
+                    node=node[0],
+                    resource_id=node[1],
+                    node_capacity_cpu_cores=node[2],
+                    node_role=node[3],
+                    cluster=cluster,
+                )
+            # if the node entry already exists but does not have a role assigned, update the node role
+            elif not tmp_node.node_role:
+                tmp_node.node_role = node[3]
+                tmp_node.save()

     def populate_pvc_table(self, cluster, pvcs):
         """Get or create an entry in the OCP cluster table."""
         LOG.info(log_json(msg="populating reporting_ocp_pvcs table", schema=self.schema, cluster=cluster))
-        with schema_context(self.schema):
-            for pvc in pvcs:
-                OCPPVC.objects.get_or_create(persistent_volume=pvc[0], persistent_volume_claim=pvc[1], cluster=cluster)
+        for pvc in pvcs:
+            OCPPVC.objects.get_or_create(persistent_volume=pvc[0], persistent_volume_claim=pvc[1], cluster=cluster)

     def populate_project_table(self, cluster, projects):
         """Get or create an entry in the OCP cluster table."""
         LOG.info(log_json(msg="populating reporting_ocp_projects table", schema=self.schema, cluster=cluster))
-        with schema_context(self.schema):
-            for project in projects:
-                OCPProject.objects.get_or_create(project=project, cluster=cluster)
+        for project in projects:
+            OCPProject.objects.get_or_create(project=project, cluster=cluster)

     def get_nodes_trino(self, source_uuid, start_date, end_date):
         """Get the nodes from an OpenShift cluster."""
@@ -1081,37 +1019,30 @@ def get_projects_trino(self, source_uuid, start_date, end_date):

     def get_cluster_for_provider(self, provider_uuid):
         """Return the cluster entry for a provider UUID."""
-        with schema_context(self.schema):
-            cluster = OCPCluster.objects.filter(provider_id=provider_uuid).first()
-        return cluster
+        return OCPCluster.objects.filter(provider_id=provider_uuid).first()

     def get_nodes_for_cluster(self, cluster_id):
         """Get all nodes for an OCP cluster."""
-        with schema_context(self.schema):
-            nodes = (
-                OCPNode.objects.filter(cluster_id=cluster_id)
-                .exclude(node__exact="")
-                .values_list("node", "resource_id")
-            )
-            nodes = [(node[0], node[1]) for node in nodes]
+        nodes = (
+            OCPNode.objects.filter(cluster_id=cluster_id).exclude(node__exact="").values_list("node", "resource_id")
+        )
+        nodes = [(node[0], node[1]) for node in nodes]
         return nodes

     def get_pvcs_for_cluster(self, cluster_id):
         """Get all nodes for an OCP cluster."""
-        with schema_context(self.schema):
-            pvcs = (
-                OCPPVC.objects.filter(cluster_id=cluster_id)
-                .exclude(persistent_volume__exact="")
-                .values_list("persistent_volume", "persistent_volume_claim")
-            )
-            pvcs = [(pvc[0], pvc[1]) for pvc in pvcs]
+        pvcs = (
+            OCPPVC.objects.filter(cluster_id=cluster_id)
+            .exclude(persistent_volume__exact="")
+            .values_list("persistent_volume", "persistent_volume_claim")
+        )
+        pvcs = [(pvc[0], pvc[1]) for pvc in pvcs]
         return pvcs

     def get_projects_for_cluster(self, cluster_id):
         """Get all nodes for an OCP cluster."""
-        with schema_context(self.schema):
-            projects = OCPProject.objects.filter(cluster_id=cluster_id).values_list("project")
-            projects = [project[0] for project in projects]
+        projects = OCPProject.objects.filter(cluster_id=cluster_id).values_list("project")
+        projects = [project[0] for project in projects]
         return projects

     def get_openshift_topology_for_multiple_providers(self, provider_uuids):
diff --git a/koku/masu/test/database/test_ocp_report_db_accessor.py b/koku/masu/test/database/test_ocp_report_db_accessor.py
index 9c472e3858..2e1fa556e4 100644
--- a/koku/masu/test/database/test_ocp_report_db_accessor.py
+++ b/koku/masu/test/database/test_ocp_report_db_accessor.py
@@ -12,25 +12,18 @@
 from unittest.mock import Mock
 from unittest.mock import patch

-from dateutil import relativedelta
 from django.conf import settings
 from django.db.models import Max
 from django.db.models import Q
 from django.db.models import Sum
-from django.db.models.query import QuerySet
-from django_tenants.utils import schema_context
 from trino.exceptions import TrinoExternalError

 from api.iam.test.iam_test_case import FakeTrinoConn
 from api.provider.models import Provider
-from api.utils import DateHelper
 from koku import trino_database as trino_db
-from masu.database import AWS_CUR_TABLE_MAP
 from masu.database import OCP_REPORT_TABLE_MAP
 from masu.database.ocp_report_db_accessor import OCPReportDBAccessor
-from masu.external.date_accessor import DateAccessor
 from masu.test import MasuTestCase
-from masu.test.database.helpers import ReportObjectCreator
 from reporting.models import OCPStorageVolumeLabelSummary
 from reporting.models import OCPUsageLineItemDailySummary
 from reporting.models import OCPUsagePodLabelSummary
@@ -39,93 +32,43 @@
 from reporting.provider.ocp.models import OCPNode
 from reporting.provider.ocp.models import OCPProject
 from reporting.provider.ocp.models import OCPPVC
+from reporting.provider.ocp.models import OCPUsageReportPeriod


 class OCPReportDBAccessorTest(MasuTestCase):
     """Test Cases for the OCPReportDBAccessor object."""

-    @classmethod
-    def setUpClass(cls):
-        """Set up the test class with required objects."""
-        super().setUpClass()
-
-        cls.accessor = OCPReportDBAccessor(schema=cls.schema)
-        cls.report_schema = cls.accessor.report_schema
-        cls.creator = ReportObjectCreator(cls.schema)
-        cls.all_tables = list(OCP_REPORT_TABLE_MAP.values())
-
     def setUp(self):
         """Set up a test with database objects."""
         super().setUp()
+        self.accessor = OCPReportDBAccessor(schema=self.schema)
+        self.report_schema = self.accessor.report_schema
+
         self.cluster_id = "testcluster"
         self.ocp_provider_uuid = self.ocp_provider.uuid

-        self.reporting_period = self.creator.create_ocp_report_period(
-            provider_uuid=self.ocp_provider_uuid, cluster_id=self.cluster_id
-        )
-
     def test_initializer(self):
         """Test initializer."""
         self.assertIsNotNone(self.report_schema)

-    def test_get_db_obj_query_default(self):
-        """Test that a query is returned."""
-        table_name = random.choice(self.all_tables)
-
-        query = self.accessor._get_db_obj_query(table_name)
-
-        self.assertIsInstance(query, QuerySet)
-
-    def test_get_current_usage_period(self):
-        """Test that the most recent usage period is returned."""
-        current_report_period = self.accessor.get_current_usage_period(self.ocp_provider_uuid)
-        self.assertIsNotNone(current_report_period.report_period_start)
-        self.assertIsNotNone(current_report_period.report_period_end)
-
-    def test_get_usage_period_by_dates_and_cluster(self):
-        """Test that report periods are returned by dates & cluster filter."""
-        period_start = DateAccessor().today_with_timezone("UTC").replace(day=1)
-        period_end = period_start + relativedelta.relativedelta(months=1)
-        prev_period_start = period_start - relativedelta.relativedelta(months=1)
-        prev_period_end = prev_period_start + relativedelta.relativedelta(months=1)
-        reporting_period = self.creator.create_ocp_report_period(
-            self.ocp_provider_uuid, period_date=period_start, cluster_id="0001"
-        )
-        prev_reporting_period = self.creator.create_ocp_report_period(
-            self.ocp_provider_uuid, period_date=prev_period_start, cluster_id="0002"
-        )
-        with schema_context(self.schema):
-            periods = self.accessor.get_usage_period_by_dates_and_cluster(
-                period_start.date(), period_end.date(), "0001"
-            )
-            self.assertEqual(reporting_period, periods)
-            periods = self.accessor.get_usage_period_by_dates_and_cluster(
-                prev_period_start.date(), prev_period_end.date(), "0002"
-            )
-            self.assertEqual(prev_reporting_period, periods)
-
     def test_get_usage_period_query_by_provider(self):
         """Test that periods are returned filtered by provider."""
-        provider_uuid = self.ocp_provider_uuid
-
-        period_query = self.accessor.get_usage_period_query_by_provider(provider_uuid)
-        with schema_context(self.schema):
+        with self.accessor as acc:
+            provider_uuid = self.ocp_provider_uuid
+            period_query = acc.get_usage_period_query_by_provider(provider_uuid)
             periods = period_query.all()
-
             self.assertGreater(len(periods), 0)
-
             period = periods[0]
-
             self.assertEqual(period.provider_id, provider_uuid)

     def test_report_periods_for_provider_uuid(self):
         """Test that periods are returned filtered by provider id and start date."""
-        provider_uuid = self.ocp_provider_uuid
-        start_date = str(self.reporting_period.report_period_start)
-
-        period = self.accessor.report_periods_for_provider_uuid(provider_uuid, start_date)
-        with schema_context(self.schema):
+        with self.accessor as acc:
+            provider_uuid = self.ocp_provider_uuid
+            reporting_period = OCPUsageReportPeriod.objects.filter(provider=self.ocp_provider_uuid).first()
+            start_date = str(reporting_period.report_period_start)
+            period = acc.report_periods_for_provider_uuid(provider_uuid, start_date)
             self.assertEqual(period.provider_id, provider_uuid)

     @patch("masu.database.ocp_report_db_accessor.trino_table_exists")
@@ -135,17 +78,18 @@ def test_populate_line_item_daily_summary_table_trino(self, mock_execute, *args)
         """
         Test that OCP trino processing calls executescript
         """
-        dh = DateHelper()
-        start_date = dh.this_month_start
-        end_date = dh.next_month_start
+
+        start_date = self.dh.this_month_start
+        end_date = self.dh.next_month_start
         cluster_id = "ocp-cluster"
         cluster_alias = "OCP FTW"
         report_period_id = 1
         source = self.provider_uuid
-        self.accessor.populate_line_item_daily_summary_table_trino(
-            start_date, end_date, report_period_id, cluster_id, cluster_alias, source
-        )
-        mock_execute.assert_called()
+        with self.accessor as acc:
+            acc.populate_line_item_daily_summary_table_trino(
+                start_date, end_date, report_period_id, cluster_id, cluster_alias, source
+            )
+            mock_execute.assert_called()

     @patch("masu.database.ocp_report_db_accessor.trino_table_exists")
     @patch("masu.database.ocp_report_db_accessor.pkgutil.get_data")
@@ -190,12 +134,10 @@ def test_populate_tag_based_usage_costs(self):  # noqa: C901
             "storage_gb_usage_per_month": ["storage", "persistentvolumeclaim_usage_gigabyte_months"],
             "storage_gb_request_per_month": ["storage", "volume_request_storage_gigabyte_months"],
         }
-
-        dh = DateHelper()
-        start_date = dh.this_month_start
-        end_date = dh.this_month_end
+        start_date = self.dh.this_month_start
+        end_date = self.dh.this_month_end
         self.cluster_id = "OCP-on-AWS"
-        with schema_context(self.schema):
+        with self.accessor as acc:
             # define the two usage types to test
             usage_types = ("Infrastructure", "Supplementary")
             for usage_type in usage_types:
@@ -273,7 +215,7 @@ def test_populate_tag_based_usage_costs(self):  # noqa: C901
                 )

                 # call populate monthly tag_cost with the rates defined above
-                self.accessor.populate_tag_usage_costs(
+                acc.populate_tag_usage_costs(
                     infrastructure_rates, supplementary_rates, start_date, end_date, self.cluster_id
                 )

@@ -354,12 +296,10 @@ def test_populate_tag_based_default_usage_costs(self):  # noqa: C901
             "storage_gb_usage_per_month": ["storage", "persistentvolumeclaim_usage_gigabyte_months"],
             "storage_gb_request_per_month": ["storage", "volume_request_storage_gigabyte_months"],
         }
-
-        dh = DateHelper()
-        start_date = dh.this_month_start
-        end_date = dh.this_month_end
+        start_date = self.dh.this_month_start
+        end_date = self.dh.this_month_end
         self.cluster_id = "OCP-on-AWS"
-        with schema_context(self.schema):
+        with self.accessor as acc:
             # define the two usage types to test
             usage_types = ("Infrastructure", "Supplementary")
             for usage_type in usage_types:
@@ -435,7 +375,7 @@ def test_populate_tag_based_default_usage_costs(self):  # noqa: C901
                 )

                 # call populate monthly tag_cost with the rates defined above
-                self.accessor.populate_tag_usage_default_costs(
+                acc.populate_tag_usage_default_costs(
                     infrastructure_rates, supplementary_rates, start_date, end_date, self.cluster_id
                 )

@@ -510,20 +450,18 @@

     def test_update_line_item_daily_summary_with_enabled_tags(self):
         """Test that we filter the daily summary table's tags with only enabled tags."""
-        dh = DateHelper()
-        start_date = dh.this_month_start.date()
-        end_date = dh.this_month_end.date()
-
-        report_period = self.accessor.report_periods_for_provider_uuid(self.ocp_provider_uuid, start_date)
+        start_date = self.dh.this_month_start.date()
+        end_date = self.dh.this_month_end.date()
+        with self.accessor as acc:
+            report_period = acc.report_periods_for_provider_uuid(self.ocp_provider_uuid, start_date)

-        with schema_context(self.schema):
             OCPUsagePodLabelSummary.objects.all().delete()
             OCPStorageVolumeLabelSummary.objects.all().delete()
             key_to_keep = EnabledTagKeys.objects.filter(provider_type=Provider.PROVIDER_OCP).filter(key="app").first()
             EnabledTagKeys.objects.filter(provider_type=Provider.PROVIDER_OCP).update(enabled=False)
             EnabledTagKeys.objects.filter(provider_type=Provider.PROVIDER_OCP).filter(key="app").update(enabled=True)
             report_period_ids = [report_period.id]
-            self.accessor.update_line_item_daily_summary_with_enabled_tags(start_date, end_date, report_period_ids)
+            acc.update_line_item_daily_summary_with_enabled_tags(start_date, end_date, report_period_ids)
             tags = (
                 OCPUsageLineItemDailySummary.objects.filter(
                     usage_start__gte=start_date, report_period_id__in=report_period_ids
@@ -558,21 +496,14 @@

     def test_delete_line_item_daily_summary_entries_for_date_range(self):
         """Test that daily summary rows are deleted."""
-        with schema_context(self.schema):
+        with self.accessor as acc:
             start_date = OCPUsageLineItemDailySummary.objects.aggregate(Max("usage_start")).get("usage_start__max")
             end_date = start_date
-
-        table_query = OCPUsageLineItemDailySummary.objects.filter(
-            source_uuid=self.ocp_provider_uuid, usage_start__gte=start_date, usage_start__lte=end_date
-        )
-        with schema_context(self.schema):
+            table_query = OCPUsageLineItemDailySummary.objects.filter(
+                source_uuid=self.ocp_provider_uuid, usage_start__gte=start_date, usage_start__lte=end_date
+            )
             self.assertNotEqual(table_query.count(), 0)
-
-        self.accessor.delete_line_item_daily_summary_entries_for_date_range(
-            self.ocp_provider_uuid, start_date, end_date
-        )
-
-        with schema_context(self.schema):
+            acc.delete_line_item_daily_summary_entries_for_date_range(self.ocp_provider_uuid, start_date, end_date)
             self.assertEqual(table_query.count(), 0)

     def test_table_properties(self):
@@ -580,14 +511,13 @@

     def test_table_map(self):
         self.assertEqual(self.accessor._table_map, OCP_REPORT_TABLE_MAP)
-        self.assertEqual(self.accessor._aws_table_map, AWS_CUR_TABLE_MAP)

     @patch("masu.database.ocp_report_db_accessor.OCPReportDBAccessor._execute_trino_raw_sql_query")
     def test_get_ocp_infrastructure_map_trino(self, mock_trino):
         """Test that Trino is used to find matched tags."""
-        dh = DateHelper()
-        start_date = dh.this_month_start.date()
-        end_date = dh.this_month_end.date()
+
+        start_date = self.dh.this_month_start.date()
+        end_date = self.dh.this_month_end.date()

         self.accessor.get_ocp_infrastructure_map_trino(start_date, end_date)
         mock_trino.assert_called()
@@ -595,9 +525,9 @@
     @patch("masu.database.ocp_report_db_accessor.OCPReportDBAccessor._execute_trino_raw_sql_query")
     def test_get_ocp_infrastructure_map_trino_gcp_resource(self, mock_trino):
         """Test that Trino is used to find matched resource names."""
-        dh = DateHelper()
-        start_date = dh.this_month_start.date()
-        end_date = dh.this_month_end.date()
+
+        start_date = self.dh.this_month_start.date()
+        end_date = self.dh.this_month_end.date()
         expected_log = "INFO:masu.util.gcp.common:OCP GCP matching set to resource level"
         with self.assertLogs("masu.util.gcp.common", level="INFO") as logger:
             self.accessor.get_ocp_infrastructure_map_trino(
@@ -609,9 +539,9 @@
     @patch("masu.database.ocp_report_db_accessor.OCPReportDBAccessor._execute_trino_raw_sql_query")
     def test_get_ocp_infrastructure_map_trino_gcp_with_disabled_resource_matching(self, mock_trino):
         """Test that Trino is used to find matched resource names."""
-        dh = DateHelper()
-        start_date = dh.this_month_start.date()
-        end_date = dh.this_month_end.date()
+
+        start_date = self.dh.this_month_start.date()
+        end_date = self.dh.this_month_end.date()
         expected_log = f"INFO:masu.util.gcp.common:GCP resource matching disabled for {self.schema}"
         with patch("masu.util.gcp.common.is_gcp_resource_matching_disabled", return_value=True):
             with self.assertLogs("masu", level="INFO") as logger:
@@ -642,15 +572,16 @@ def test_populate_openshift_cluster_information_tables(
         mock_table.return_value = True
         cluster_id = uuid.uuid4()
         cluster_alias = "test-cluster-1"
-        dh = DateHelper()
-        start_date = dh.this_month_start.date()
-        end_date = dh.this_month_end.date()
-        self.accessor.populate_openshift_cluster_information_tables(
-            self.aws_provider, cluster_id, cluster_alias, start_date, end_date
-        )
+        start_date = self.dh.this_month_start.date()
+        end_date = self.dh.this_month_end.date()
+
+        with self.accessor as acc:
+
+            acc.populate_openshift_cluster_information_tables(
+                self.aws_provider, cluster_id, cluster_alias, start_date, end_date
+            )

-        with schema_context(self.schema):
             self.assertIsNotNone(OCPCluster.objects.filter(cluster_id=cluster_id).first())
             for node in nodes:
                 db_node = OCPNode.objects.filter(node=node).first()
@@ -665,14 +596,14 @@ def test_populate_openshift_cluster_information_tables(
             for project in projects:
                 self.assertIsNotNone(OCPProject.objects.filter(project=project).first())

-        mock_table.reset_mock()
-        mock_get_pvcs.reset_mock()
-        mock_table.return_value = False
+            mock_table.reset_mock()
+            mock_get_pvcs.reset_mock()
+            mock_table.return_value = False

-        self.accessor.populate_openshift_cluster_information_tables(
-            self.ocp_provider, cluster_id, cluster_alias, start_date, end_date
-        )
-        mock_get_pvcs.assert_not_called()
+            acc.populate_openshift_cluster_information_tables(
+                self.ocp_provider, cluster_id, cluster_alias, start_date, end_date
+            )
+            mock_get_pvcs.assert_not_called()

     @patch("masu.database.ocp_report_db_accessor.trino_table_exists")
     @patch("masu.database.ocp_report_db_accessor.OCPReportDBAccessor.get_projects_trino")
@@ -695,22 +626,22 @@ def test_get_openshift_topology_for_multiple_providers(
         mock_table.return_value = True
         cluster_id = str(uuid.uuid4())
         cluster_alias = "test-cluster-1"
-        dh = DateHelper()
-        start_date = dh.this_month_start.date()
-        end_date = dh.this_month_end.date()
-        # Using the aws_provider to short cut this test instead of creating a brand
-        # new provider. The OCP providers already have data, and can't be used here
-        self.accessor.populate_openshift_cluster_information_tables(
-            self.aws_provider, cluster_id, cluster_alias, start_date, end_date
-        )
+        start_date = self.dh.this_month_start.date()
+        end_date = self.dh.this_month_end.date()
+
+        with self.accessor as acc:
+            # Using the aws_provider to short cut this test instead of creating a brand
+            # new provider. The OCP providers already have data, and can't be used here
+            acc.populate_openshift_cluster_information_tables(
+                self.aws_provider, cluster_id, cluster_alias, start_date, end_date
+            )

-        with schema_context(self.schema):
             cluster = OCPCluster.objects.filter(cluster_id=cluster_id).first()
             nodes = OCPNode.objects.filter(cluster=cluster).all()
             pvcs = OCPPVC.objects.filter(cluster=cluster).all()
             projects = OCPProject.objects.filter(cluster=cluster).all()

-            topology = self.accessor.get_openshift_topology_for_multiple_providers([self.aws_provider_uuid])
+            topology = acc.get_openshift_topology_for_multiple_providers([self.aws_provider_uuid])
             self.assertEqual(len(topology), 1)
             topo = topology[0]
             self.assertEqual(topo.get("cluster_id"), cluster_id)
@@ -739,16 +670,16 @@ def test_get_filtered_openshift_topology_for_multiple_providers(self, mock_get_n
         mock_table.return_value = True
         cluster_id = str(uuid.uuid4())
         cluster_alias = "test-cluster-1"
-        dh = DateHelper()
-        start_date = dh.this_month_start.date()
-        end_date = dh.this_month_end.date()
-        with schema_context(self.schema):
+        start_date = self.dh.this_month_start.date()
+        end_date = self.dh.this_month_end.date()
+
+        with self.accessor as acc:
             cluster = OCPCluster(
                 cluster_id=cluster_id, cluster_alias=cluster_alias, provider_id=self.gcp_provider_uuid
             )
             cluster.save()
-            topology = self.accessor.get_filtered_openshift_topology_for_multiple_providers(
+            topology = acc.get_filtered_openshift_topology_for_multiple_providers(
                 [self.gcp_provider_uuid], start_date, end_date
             )
             self.assertEqual(len(topology), 1)
@@ -765,13 +696,13 @@ def test_populate_node_table_update_role(self):
         node_info = ["node_role_test_node", "node_role_test_id", 1, "worker"]
         cluster_id = str(uuid.uuid4())
         cluster_alias = "node_role_test"
-        cluster = self.accessor.populate_cluster_table(self.aws_provider, cluster_id, cluster_alias)
-        with schema_context(self.schema):
+        with self.accessor as acc:
+            cluster = acc.populate_cluster_table(self.aws_provider, cluster_id, cluster_alias)
             node = OCPNode.objects.create(
                 node=node_info[0], resource_id=node_info[1], node_capacity_cpu_cores=node_info[2], cluster=cluster
             )
             self.assertIsNone(node.node_role)
-            self.accessor.populate_node_table(cluster, [node_info])
+            acc.populate_node_table(cluster, [node_info])
             node = OCPNode.objects.get(
                 node=node_info[0], resource_id=node_info[1], node_capacity_cpu_cores=node_info[2], cluster=cluster
             )
@@ -782,12 +713,11 @@ def test_populate_cluster_table_update_cluster_alias(self):
         cluster_id = str(uuid.uuid4())
         cluster_alias = "cluster_alias"
         new_cluster_alias = "new_cluster_alias"
-        self.accessor.populate_cluster_table(self.aws_provider, cluster_id, cluster_alias)
-
-        with schema_context(self.schema):
+        with self.accessor as acc:
+            acc.populate_cluster_table(self.aws_provider, cluster_id, cluster_alias)
             cluster = OCPCluster.objects.filter(cluster_id=cluster_id).first()
             self.assertEqual(cluster.cluster_alias, cluster_alias)
-            self.accessor.populate_cluster_table(self.aws_provider, cluster_id, new_cluster_alias)
+            acc.populate_cluster_table(self.aws_provider, cluster_id, new_cluster_alias)
             cluster = OCPCluster.objects.filter(cluster_id=cluster_id).first()
             self.assertEqual(cluster.cluster_alias, new_cluster_alias)

@@ -795,13 +725,13 @@ def test_populate_cluster_table_delete_duplicates(self):
         """Test updating cluster alias for duplicate entry in the cluster table."""
         cluster_id = str(uuid.uuid4())
         new_cluster_alias = "new_cluster_alias"
-        self.accessor.populate_cluster_table(self.aws_provider, cluster_id, "cluster_alias")
-        with schema_context(self.schema):
+        with self.accessor as acc:
+            acc.populate_cluster_table(self.aws_provider, cluster_id, "cluster_alias")
             # Forcefully create a second entry
             OCPCluster.objects.get_or_create(
                 cluster_id=cluster_id, cluster_alias=self.aws_provider.name, provider_id=self.aws_provider_uuid
             )
-            self.accessor.populate_cluster_table(self.aws_provider, cluster_id, new_cluster_alias)
+            acc.populate_cluster_table(self.aws_provider, cluster_id, new_cluster_alias)
             clusters = OCPCluster.objects.filter(cluster_id=cluster_id)
             self.assertEqual(len(clusters), 1)
             cluster = clusters.first()
@@ -812,14 +742,14 @@ def test_populate_node_table_second_time_no_change(self):
         node_info = ["node_role_test_node", "node_role_test_id", 1, "worker"]
         cluster_id = str(uuid.uuid4())
         cluster_alias = "node_role_test"
-        cluster = self.accessor.populate_cluster_table(self.aws_provider, cluster_id, cluster_alias)
-        with schema_context(self.schema):
-            self.accessor.populate_node_table(cluster, [node_info])
+        with self.accessor as acc:
+            cluster = acc.populate_cluster_table(self.aws_provider, cluster_id, cluster_alias)
+            acc.populate_node_table(cluster, [node_info])
             node_count = OCPNode.objects.filter(
                 node=node_info[0], resource_id=node_info[1], node_capacity_cpu_cores=node_info[2], cluster=cluster
             ).count()
             self.assertEqual(node_count, 1)
-            self.accessor.populate_node_table(cluster, [node_info])
+            acc.populate_node_table(cluster, [node_info])
             node_count = OCPNode.objects.filter(
                 node=node_info[0], resource_id=node_info[1], node_capacity_cpu_cores=node_info[2], cluster=cluster
             ).count()
@@ -827,26 +757,22 @@

     def test_delete_infrastructure_raw_cost_from_daily_summary(self):
         """Test that infra raw cost is deleted."""
-        dh = DateHelper()
-        start_date = dh.this_month_start.date()
-        end_date = dh.this_month_end.date()
-        report_period = self.accessor.report_periods_for_provider_uuid(self.ocpaws_provider_uuid, start_date)
-        with schema_context(self.schema):
+        with self.accessor as acc:
+            start_date = self.dh.this_month_start.date()
+            end_date = self.dh.this_month_end.date()
+            report_period = acc.report_periods_for_provider_uuid(self.ocpaws_provider_uuid, start_date)
             report_period_id = report_period.id
             count = OCPUsageLineItemDailySummary.objects.filter(
                 report_period_id=report_period_id, usage_start__gte=start_date, infrastructure_raw_cost__gt=0
             ).count()
-        self.assertNotEqual(count, 0)
-
-        self.accessor.delete_infrastructure_raw_cost_from_daily_summary(
-            self.ocpaws_provider_uuid, report_period_id, start_date, end_date
-        )
-
-        with schema_context(self.schema):
+            self.assertNotEqual(count, 0)
+            acc.delete_infrastructure_raw_cost_from_daily_summary(
+                self.ocpaws_provider_uuid, report_period_id, start_date, end_date
+            )
             count = OCPUsageLineItemDailySummary.objects.filter(
                 report_period_id=report_period_id, usage_start__gte=start_date, infrastructure_raw_cost__gt=0
             ).count()
-            self.assertEqual(count, 0)
+            self.assertEqual(count, 0)

     @patch("masu.database.ocp_report_db_accessor.OCPReportDBAccessor.table_exists_trino")
     @patch("masu.database.ocp_report_db_accessor.OCPReportDBAccessor._execute_trino_raw_sql_query")
@@ -930,15 +856,14 @@ def test_get_max_min_timestamp_from_parquet(self, mock_query):

     def test_delete_all_except_infrastructure_raw_cost_from_daily_summary(self):
         """Test that deleting saves OCP on Cloud data."""
-        dh = DateHelper()
-        start_date = dh.this_month_start
-        end_date = dh.this_month_end
+        with self.accessor as acc:
+            start_date = self.dh.this_month_start
+            end_date = self.dh.this_month_end

-        # First test an OCP on Cloud source to make sure we don't delete that data
-        provider_uuid = self.ocp_on_aws_ocp_provider.uuid
-        report_period = self.accessor.report_periods_for_provider_uuid(provider_uuid, start_date)
+            # First test an OCP on Cloud source to make sure we don't delete that data
+            provider_uuid = self.ocp_on_aws_ocp_provider.uuid
+            report_period = acc.report_periods_for_provider_uuid(provider_uuid, start_date)

-        with schema_context(self.schema):
             report_period_id = report_period.id
             initial_non_raw_count = (
                 OCPUsageLineItemDailySummary.objects.filter(
@@ -954,11 +879,10 @@ def test_delete_all_except_infrastructure_raw_cost_from_daily_summary(self):
                 report_period_id=report_period_id,
             ).count()

-        self.accessor.delete_all_except_infrastructure_raw_cost_from_daily_summary(
-            provider_uuid, report_period_id, start_date, end_date
-        )
+            acc.delete_all_except_infrastructure_raw_cost_from_daily_summary(
+                provider_uuid, report_period_id, start_date, end_date
+            )

-        with schema_context(self.schema):
             new_non_raw_count = OCPUsageLineItemDailySummary.objects.filter(
                 Q(infrastructure_raw_cost__isnull=True) | Q(infrastructure_raw_cost=0),
                 report_period_id=report_period_id,
@@ -968,15 +892,14 @@ def test_delete_all_except_infrastructure_raw_cost_from_daily_summary(self):
                 report_period_id=report_period_id,
             ).count()

-        self.assertEqual(initial_non_raw_count, 0)
-        self.assertEqual(new_non_raw_count, 0)
-        self.assertEqual(initial_raw_count, new_raw_count)
+            self.assertEqual(initial_non_raw_count, 0)
+            self.assertEqual(new_non_raw_count, 0)
+            self.assertEqual(initial_raw_count, new_raw_count)

-        # Now test an on prem OCP cluster to make sure we still remove non raw costs
-        provider_uuid = self.ocp_provider.uuid
-        report_period = self.accessor.report_periods_for_provider_uuid(provider_uuid, start_date)
+            # Now test an on prem OCP cluster to make sure we still remove non raw costs
+            provider_uuid = self.ocp_provider.uuid
+            report_period = acc.report_periods_for_provider_uuid(provider_uuid, start_date)

-        with schema_context(self.schema):
             report_period_id = report_period.id
             initial_non_raw_count = OCPUsageLineItemDailySummary.objects.filter(
                 Q(infrastructure_raw_cost__isnull=True) | Q(infrastructure_raw_cost=0),
@@ -987,11 +910,10 @@ def test_delete_all_except_infrastructure_raw_cost_from_daily_summary(self):
                 report_period_id=report_period_id,
             ).count()

-        self.accessor.delete_all_except_infrastructure_raw_cost_from_daily_summary(
-            provider_uuid, report_period_id, start_date, end_date
-        )
+            acc.delete_all_except_infrastructure_raw_cost_from_daily_summary(
+                provider_uuid, report_period_id, start_date, end_date
+            )

-        with schema_context(self.schema):
             new_non_raw_count = OCPUsageLineItemDailySummary.objects.filter(
                 Q(infrastructure_raw_cost__isnull=True) | Q(infrastructure_raw_cost=0),
                 report_period_id=report_period_id,
@@ -1001,42 +923,46 @@ def test_delete_all_except_infrastructure_raw_cost_from_daily_summary(self):
                 report_period_id=report_period_id,
             ).count()

-        self.assertNotEqual(initial_non_raw_count, new_non_raw_count)
-        self.assertEqual(initial_raw_count, 0)
-        self.assertEqual(new_raw_count, 0)
+            self.assertNotEqual(initial_non_raw_count, new_non_raw_count)
+            self.assertEqual(initial_raw_count, 0)
+            self.assertEqual(new_raw_count, 0)

     def test_populate_monthly_cost_sql_no_report_period(self):
         """Test that updating monthly costs without a matching report period no longer throws an error"""
         start_date = "2000-01-01"
         end_date = "2000-02-01"
         with self.assertLogs("masu.database.ocp_report_db_accessor", level="INFO") as logger:
-            self.accessor.populate_monthly_cost_sql("", "", "", start_date, end_date, "", self.provider_uuid)
-            self.assertIn("no report period for OCP provider", logger.output[0])
+            with self.accessor as acc:
+                acc.populate_monthly_cost_sql("", "", "", start_date, end_date, "", self.provider_uuid)
+                self.assertIn("no report period for OCP provider", logger.output[0])

     def test_populate_monthly_cost_tag_sql_no_report_period(self):
         """Test that updating monthly costs without a matching report period no longer throws an error"""
         start_date = "2000-01-01"
         end_date = "2000-02-01"
         with self.assertLogs("masu.database.ocp_report_db_accessor", level="INFO") as logger:
-            self.accessor.populate_monthly_tag_cost_sql("", "", "", "", start_date, end_date, "", self.provider_uuid)
-            self.assertIn("no report period for OCP provider", logger.output[0])
+            with self.accessor as acc:
+                acc.populate_monthly_tag_cost_sql("", "", "", "", start_date, end_date, "", self.provider_uuid)
+                self.assertIn("no report period for OCP provider", logger.output[0])

     def test_populate_usage_costs_new_columns_no_report_period(self):
         """Test that updating new column usage costs without a matching report period no longer throws an error"""
         start_date = "2000-01-01"
         end_date = "2000-02-01"
         with self.assertLogs("masu.database.ocp_report_db_accessor", level="INFO") as logger:
-            self.accessor.populate_usage_costs("", "", start_date, end_date, self.provider_uuid)
-            self.assertIn("no report period for OCP provider", logger.output[0])
+            with self.accessor as acc:
+                acc.populate_usage_costs("", "", start_date, end_date, self.provider_uuid)
+                self.assertIn("no report period for OCP provider", logger.output[0])

     def test_populate_platform_and_worker_distributed_cost_sql_no_report_period(self):
         """Test that updating monthly costs without a matching report period no longer throws an error"""
         start_date = "2000-01-01"
         end_date = "2000-02-01"
-        result = self.accessor.populate_platform_and_worker_distributed_cost_sql(
-            start_date, end_date, self.provider_uuid, {"platform_cost": True}
-        )
-        self.assertIsNone(result)
+        with self.accessor as acc:
+            result = acc.populate_platform_and_worker_distributed_cost_sql(
+                start_date, end_date, self.provider_uuid, {"platform_cost": True}
+            )
+            self.assertIsNone(result)

     @patch("masu.database.ocp_report_db_accessor.pkgutil.get_data")
     @patch("masu.database.ocp_report_db_accessor.OCPReportDBAccessor._execute_raw_sql_query")
@@ -1052,7 +978,6 @@ def get_pkgutil_values(file):
         masu_database = "masu.database"
         start_date = self.dh.this_month_start.date()
         end_date = self.dh.this_month_end.date()
-        accessor = OCPReportDBAccessor(schema=self.schema)
         default_sql_params = {
             "start_date": start_date,
             "end_date": end_date,
@@ -1068,15 +993,17 @@ def get_pkgutil_values(file):
         ]
         mock_jinja = Mock()
         mock_jinja.side_effect = side_effect
-        accessor.prepare_query = mock_jinja
-        accessor.populate_platform_and_worker_distributed_cost_sql(
-            start_date, end_date, self.ocp_test_provider_uuid, {"worker_cost": True, "platform_cost": True}
-        )
-        expected_calls = [
-            call(masu_database, "sql/openshift/cost_model/distribute_worker_cost.sql"),
-            call(masu_database, "sql/openshift/cost_model/distribute_platform_cost.sql"),
-        ]
-        for expected_call in expected_calls:
-            self.assertIn(expected_call, mock_data_get.call_args_list)
-        mock_sql_execute.assert_called()
-        self.assertEqual(len(mock_sql_execute.call_args_list), 2)
+
+        with self.accessor as acc:
+            acc.prepare_query = mock_jinja
+            acc.populate_platform_and_worker_distributed_cost_sql(
+                start_date, end_date, self.ocp_test_provider_uuid, {"worker_cost": True, "platform_cost": True}
+            )
+            expected_calls = [
+                call(masu_database, "sql/openshift/cost_model/distribute_worker_cost.sql"),
+                call(masu_database, "sql/openshift/cost_model/distribute_platform_cost.sql"),
+            ]
+            for expected_call in expected_calls:
+                self.assertIn(expected_call, mock_data_get.call_args_list)
+            mock_sql_execute.assert_called()
+            self.assertEqual(len(mock_sql_execute.call_args_list), 2)
diff --git a/koku/masu/test/processor/ocp/test_ocp_cost_model_cost_updater.py b/koku/masu/test/processor/ocp/test_ocp_cost_model_cost_updater.py
index faa7263d9d..e3b9ba8eca 100644
--- a/koku/masu/test/processor/ocp/test_ocp_cost_model_cost_updater.py
+++ b/koku/masu/test/processor/ocp/test_ocp_cost_model_cost_updater.py
@@ -20,6 +20,7 @@
 from masu.test import MasuTestCase
 from masu.util.ocp.common import get_amortized_monthly_cost_model_rate
 from reporting.models import OCPUsageLineItemDailySummary
+from reporting.provider.ocp.models import OCPUsageReportPeriod


 class OCPCostModelCostUpdaterTest(MasuTestCase):
@@ -159,8 +160,13 @@ def test_update_monthly_cost_infrastructure(self, mock_cost_accessor):
         mock_cost_accessor.return_value.__enter__.return_value.infrastructure_rates = infrastructure_rates
         mock_cost_accessor.return_value.__enter__.return_value.supplementary_rates = {}
         mock_cost_accessor.return_value.__enter__.return_value.distribution_info = self.distribution_info
+        with schema_context(self.schema):
+            usage_period = (
+                OCPUsageReportPeriod.objects.filter(provider_id=self.provider_uuid)
+                .order_by("-report_period_start")
+                .first()
+            )

-        usage_period = self.accessor.get_current_usage_period(self.provider_uuid)
         start_date = usage_period.report_period_start.date()
         end_date = usage_period.report_period_end.date() - relativedelta(days=1)
         updater = OCPCostModelCostUpdater(schema=self.schema, provider=self.provider)
@@ -200,12 +206,17 @@ def test_update_monthly_cost_supplementary(self, mock_cost_accessor):
             "platform_cost": False,
             "worker_cost": False,
         }
-        usage_period = self.accessor.get_current_usage_period(self.provider_uuid)
-        start_date = usage_period.report_period_start.date() + relativedelta(days=-1)
-        end_date = usage_period.report_period_end.date() + relativedelta(days=+1)
+        with self.accessor:
+            usage_period = (
+                OCPUsageReportPeriod.objects.filter(provider_id=self.provider_uuid)
+                .order_by("-report_period_start")
+                .first()
+            )
+            start_date = usage_period.report_period_start.date() + relativedelta(days=-1)
+            end_date = usage_period.report_period_end.date() + relativedelta(days=+1)
         updater = OCPCostModelCostUpdater(schema=self.schema, provider=self.provider)
         updater._update_monthly_cost(start_date, end_date)
-        with schema_context(self.schema):
+        with self.accessor:
             monthly_cost_row = OCPUsageLineItemDailySummary.objects.filter(
                 cost_model_rate_type="Supplementary",
                 monthly_cost_type__isnull=False,
@@ -309,7 +320,12 @@ def test_update_tag_usage_costs(self, mock_cost_accessor, mock_update_usage):
         mock_cost_accessor.return_value.__enter__.return_value.tag_infrastructure_rates = infrastructure_rates
         mock_cost_accessor.return_value.__enter__.return_value.tag_supplementary_rates = supplementary_rates

-        usage_period = self.accessor.get_current_usage_period(self.provider_uuid)
+        with schema_context(self.schema):
+            usage_period = (
+                OCPUsageReportPeriod.objects.filter(provider_id=self.provider_uuid)
+                .order_by("-report_period_start")
+                .first()
+            )
         start_date = usage_period.report_period_start.date() + relativedelta(days=-1)
         end_date = usage_period.report_period_end.date() + relativedelta(days=+1)
         updater = OCPCostModelCostUpdater(schema=self.schema, provider=self.provider)
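
Illustrative usage sketch (not part of the patch): the change drops the explicit `schema_context(self.schema)` wrappers from `OCPReportDBAccessor`, and the updated tests instead open the accessor with `with self.accessor as acc:`, so tenant-schema selection is assumed to be handled by the accessor's own context-manager behavior in `ReportDBAccessorBase`. A minimal sketch of the calling pattern this implies, with `schema` and `provider_uuid` as placeholder values:

```python
# Hypothetical example -- not part of the patch. It assumes OCPReportDBAccessor
# (via its base class) supports the context-manager protocol the tests rely on.
from masu.database.ocp_report_db_accessor import OCPReportDBAccessor


def current_report_periods(schema, provider_uuid, start_date=None):
    # Callers now enter the accessor instead of wrapping individual ORM calls
    # in django_tenants' schema_context(); schema routing happens inside it.
    with OCPReportDBAccessor(schema=schema) as accessor:
        return accessor.report_periods_for_provider_uuid(provider_uuid, start_date)
```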