diff --git a/koku/subs/subs_data_messenger.py b/koku/subs/subs_data_messenger.py index 9bb490546b..b922bcd503 100644 --- a/koku/subs/subs_data_messenger.py +++ b/koku/subs/subs_data_messenger.py @@ -49,22 +49,21 @@ def __init__(self, context, schema_name, tracing_id): self.instance_map = {} self.date_map = defaultdict(list) - def determine_azure_instance_id(self, row): - """For Azure we have to query the instance id if its not provided by a tag.""" + def determine_azure_instance_and_tenant_id(self, row): + """For Azure we have to query the instance id if its not provided by a tag and the tenant_id.""" if row["subs_resource_id"] in self.instance_map: return self.instance_map.get(row["subs_resource_id"]) - # this column comes from a user defined tag allowing us to avoid querying Azure if its present. + prov = Provider.objects.get(uuid=row["source"]) + credentials = prov.account.get("credentials") + tenant_id = credentials.get("tenant_id") if row["subs_instance"] != "": instance_id = row["subs_instance"] # attempt to query azure for instance id else: # if its a local Azure provider, don't query Azure if self.local_prov: - return "" - prov = Provider.objects.get(uuid=row["source"]) - credentials = prov.account.get("credentials") + return "", tenant_id subscription_id = credentials.get("subscription_id") - tenant_id = credentials.get("tenant_id") client_id = credentials.get("client_id") client_secret = credentials.get("client_secret") _factory = AzureClientFactory(subscription_id, tenant_id, client_id, client_secret) @@ -76,8 +75,8 @@ def determine_azure_instance_id(self, row): ) instance_id = response.vm_id - self.instance_map[row["subs_resource_id"]] = instance_id - return instance_id + self.instance_map[row["subs_resource_id"]] = (instance_id, tenant_id) + return instance_id, tenant_id def process_and_send_subs_message(self, upload_keys): """ @@ -101,7 +100,7 @@ def process_and_send_subs_message(self, upload_keys): msg_count += self.process_azure_row(row) else: # row["subs_product_ids"] is a string of numbers separated by '-' to be sent as a list - msg = self.build_subs_msg( + subs_dict = self.build_subs_dict( row["subs_resource_id"], row["subs_account"], row["subs_start_time"], @@ -112,6 +111,7 @@ def process_and_send_subs_message(self, upload_keys): row["subs_role"], row["subs_product_ids"].split("-"), ) + msg = bytes(json.dumps(subs_dict), "utf-8") self.send_kafka_message(msg) msg_count += 1 LOG.info( @@ -130,11 +130,11 @@ def send_kafka_message(self, msg): producer.produce(SUBS_TOPIC, value=msg, callback=delivery_callback) producer.poll(0) - def build_subs_msg( + def build_subs_dict( self, instance_id, billing_account_id, tstamp, expiration, cpu_count, sla, usage, role, product_ids ): - """Gathers the relevant information for the kafka message and returns the message to be delivered.""" - subs_json = { + """Gathers the relevant information for the kafka message and returns a filled dictionary of information.""" + return { "event_id": str(uuid.uuid4()), "event_source": "cost-management", "event_type": "snapshot", @@ -154,7 +154,16 @@ def build_subs_msg( "billing_provider": self.provider_type.lower(), "billing_account_id": billing_account_id, } - return bytes(json.dumps(subs_json), "utf-8") + + def build_azure_subs_dict( + self, instance_id, billing_account_id, tstamp, expiration, cpu_count, sla, usage, role, product_ids, tenant_id + ): + """Adds azure_tenant_id to the base subs dict.""" + subs_dict = self.build_subs_dict( + instance_id, billing_account_id, tstamp, expiration, cpu_count, sla, usage, role, product_ids + ) + subs_dict["azure_tenant_id"] = tenant_id + return subs_dict def process_azure_row(self, row): """Process an Azure row into subs kafka messages.""" @@ -166,14 +175,14 @@ def process_azure_row(self, row): ): return msg_count self.date_map[row["subs_start_time"]].append(row["subs_resource_id"]) - instance_id = self.determine_azure_instance_id(row) + instance_id, tenant_id = self.determine_azure_instance_and_tenant_id(row) if not instance_id: return msg_count # Azure is daily records but subs need hourly records start = parser.parse(row["subs_start_time"]) for i in range(int(row["subs_usage_quantity"])): end = start + timedelta(hours=1) - msg = self.build_subs_msg( + subs_dict = self.build_azure_subs_dict( instance_id, row["subs_account"], start.isoformat(), @@ -183,7 +192,9 @@ def process_azure_row(self, row): row["subs_usage"], row["subs_role"], row["subs_product_ids"].split("-"), + tenant_id, ) + msg = bytes(json.dumps(subs_dict), "utf-8") # move to the next hour in the range start = end self.send_kafka_message(msg) diff --git a/koku/subs/test/__init__.py b/koku/subs/test/__init__.py index b3090710f6..5e31bb0704 100644 --- a/koku/subs/test/__init__.py +++ b/koku/subs/test/__init__.py @@ -19,4 +19,5 @@ def setUpClass(cls): cls.aws_provider_type = Provider.PROVIDER_AWS_LOCAL cls.azure_provider = Provider.objects.filter(type=Provider.PROVIDER_AZURE_LOCAL).first() + cls.azure_tenant = cls.azure_provider.account.get("credentials").get("tenant_id") cls.azure_provider_type = Provider.PROVIDER_AZURE_LOCAL diff --git a/koku/subs/test/test_subs_data_messenger.py b/koku/subs/test/test_subs_data_messenger.py index 093d5b01ec..c3d7091f5c 100644 --- a/koku/subs/test/test_subs_data_messenger.py +++ b/koku/subs/test/test_subs_data_messenger.py @@ -2,7 +2,6 @@ # Copyright 2023 Red Hat Inc. # SPDX-License-Identifier: Apache-2.0 # -import json import uuid from collections import defaultdict from unittest.mock import mock_open @@ -29,10 +28,11 @@ def setUpClass(cls): @patch("subs.subs_data_messenger.os.remove") @patch("subs.subs_data_messenger.get_producer") @patch("subs.subs_data_messenger.csv.DictReader") - @patch("subs.subs_data_messenger.SUBSDataMessenger.build_subs_msg") + @patch("subs.subs_data_messenger.SUBSDataMessenger.build_subs_dict") def test_process_and_send_subs_message(self, mock_msg_builder, mock_reader, mock_producer, mock_remove): """Tests that the proper functions are called when running process_and_send_subs_message""" upload_keys = ["fake_key"] + mock_msg_builder.return_value = defaultdict(str) mock_reader.return_value = [ { "subs_start_time": "2023-07-01T01:00:00Z", @@ -54,7 +54,7 @@ def test_process_and_send_subs_message(self, mock_msg_builder, mock_reader, mock mock_msg_builder.assert_called_once() mock_producer.assert_called_once() - def test_build_subs_msg(self): + def test_build_subs_dict(self): """ Test building the kafka message body """ @@ -68,7 +68,7 @@ def test_build_subs_msg(self): sla = "Premium" product_ids = ["479", "70"] static_uuid = uuid.uuid4() - expected_subs_json = { + expected_subs_dict = { "event_id": str(static_uuid), "event_source": "cost-management", "event_type": "snapshot", @@ -88,10 +88,9 @@ def test_build_subs_msg(self): "billing_provider": "aws", "billing_account_id": lineitem_usageaccountid, } - expected = bytes(json.dumps(expected_subs_json), "utf-8") with patch("subs.subs_data_messenger.uuid.uuid4") as mock_uuid: mock_uuid.return_value = static_uuid - actual = self.messenger.build_subs_msg( + actual = self.messenger.build_subs_dict( lineitem_resourceid, lineitem_usageaccountid, lineitem_usagestartdate, @@ -102,7 +101,59 @@ def test_build_subs_msg(self): rol, product_ids, ) - self.assertEqual(expected, actual) + self.assertEqual(expected_subs_dict, actual) + + def test_build_azure_subs_dict(self): + """ + Test building the kafka message body + """ + lineitem_resourceid = "i-55555556" + lineitem_usagestartdate = "2023-07-01T01:00:00Z" + lineitem_usageenddate = "2023-07-01T02:00:00Z" + lineitem_usageaccountid = "9999999999999" + product_vcpu = "2" + usage = "Production" + rol = "Red Hat Enterprise Linux Server" + sla = "Premium" + product_ids = ["479", "70"] + tenant_id = "my-fake-id" + static_uuid = uuid.uuid4() + expected_subs_dict = { + "event_id": str(static_uuid), + "event_source": "cost-management", + "event_type": "snapshot", + "account_number": self.acct, + "org_id": self.org_id, + "service_type": "RHEL System", + "instance_id": lineitem_resourceid, + "timestamp": lineitem_usagestartdate, + "expiration": lineitem_usageenddate, + "measurements": [{"value": product_vcpu, "uom": "vCPUs"}], + "cloud_provider": "AWS", + "hardware_type": "Cloud", + "product_ids": product_ids, + "role": rol, + "sla": sla, + "usage": usage, + "billing_provider": "aws", + "billing_account_id": lineitem_usageaccountid, + "azure_tenant_id": tenant_id, + } + with patch("subs.subs_data_messenger.uuid.uuid4") as mock_uuid: + mock_uuid.return_value = static_uuid + actual = self.messenger.build_azure_subs_dict( + lineitem_resourceid, + lineitem_usageaccountid, + lineitem_usagestartdate, + lineitem_usageenddate, + product_vcpu, + sla, + usage, + rol, + product_ids, + tenant_id, + ) + self.assertEqual(expected_subs_dict, actual) @patch("subs.subs_data_messenger.get_producer") def test_send_kafka_message(self, mock_producer): @@ -111,10 +162,10 @@ def test_send_kafka_message(self, mock_producer): self.messenger.send_kafka_message(kafka_msg) mock_producer.assert_called() - def test_determine_azure_instance_id_tag(self): + def test_determine_azure_instance_and_tenant_id_tag(self): """Test getting the azure instance id from the row provided by a tag returns as expected.""" expected_instance = "waffle-house" - self.messenger.instance_map = {} + self.azure_messenger.instance_map = {} my_row = { "resourceid": "i-55555556", "subs_start_time": "2023-07-01T01:00:00Z", @@ -129,13 +180,16 @@ def test_determine_azure_instance_id_tag(self): "subs_role": "Red Hat Enterprise Linux Server", "subs_product_ids": "479-70", "subs_instance": expected_instance, + "source": self.azure_provider.uuid, } - actual = self.messenger.determine_azure_instance_id(my_row) - self.assertEqual(expected_instance, actual) + actual_instance, actual_tenant = self.azure_messenger.determine_azure_instance_and_tenant_id(my_row) + self.assertEqual(expected_instance, actual_instance) + self.assertEqual(self.azure_tenant, actual_tenant) - def test_determine_azure_instance_id_local_prov(self): + def test_determine_azure_instance_and_tenant_id_local_prov(self): """Test that a local provider does not reach out to Azure.""" - self.messenger.instance_map = {} + self.azure_messenger.instance_map = {} + expected_instance = "" my_row = { "resourceid": "i-55555556", "subs_start_time": "2023-07-01T01:00:00Z", @@ -150,14 +204,17 @@ def test_determine_azure_instance_id_local_prov(self): "subs_role": "Red Hat Enterprise Linux Server", "subs_product_ids": "479-70", "subs_instance": "", + "source": self.azure_provider.uuid, } - actual = self.azure_messenger.determine_azure_instance_id(my_row) - self.assertEqual("", actual) + actual_instance, actual_tenant = self.azure_messenger.determine_azure_instance_and_tenant_id(my_row) + self.assertEqual(expected_instance, actual_instance) + self.assertEqual(self.azure_tenant, actual_tenant) - def test_determine_azure_instance_id_from_map(self): + def test_determine_azure_instance_and_tenant_id_from_map(self): """Test getting the azure instance id from the instance map returns as expected.""" - expected = "oh-yeah" - self.messenger.instance_map["i-55555556"] = expected + expected_instance = "oh-yeah" + expected_tenant = "my-tenant" + self.azure_messenger.instance_map["i-55555556"] = (expected_instance, expected_tenant) my_row = { "resourceid": "i-55555556", "subs_start_time": "2023-07-01T01:00:00Z", @@ -172,13 +229,15 @@ def test_determine_azure_instance_id_from_map(self): "subs_role": "Red Hat Enterprise Linux Server", "subs_product_ids": "479-70", "subs_instance": "fake", + "source": self.azure_provider.uuid, } - actual = self.messenger.determine_azure_instance_id(my_row) - self.assertEqual(expected, actual) + actual_instance, actual_tenant = self.azure_messenger.determine_azure_instance_and_tenant_id(my_row) + self.assertEqual(expected_instance, actual_instance) + self.assertEqual(expected_tenant, actual_tenant) - def test_determine_azure_instance_id(self): + def test_determine_azure_instance_and_tenant_id(self): """Test getting the azure instance id from mock Azure Compute Client returns as expected.""" - expected = "my-fake-id" + expected_instance = "my-fake-id" self.messenger.instance_map = {} my_row = { "resourceid": "i-55555556", @@ -198,21 +257,24 @@ def test_determine_azure_instance_id(self): "resourcegroup": "my-fake-rg", } with patch("subs.subs_data_messenger.AzureClientFactory") as mock_factory: - mock_factory.return_value.compute_client.virtual_machines.get.return_value.vm_id = expected - actual = self.messenger.determine_azure_instance_id(my_row) - self.assertEqual(expected, actual) + mock_factory.return_value.compute_client.virtual_machines.get.return_value.vm_id = expected_instance + actual_instance, actual_tenant = self.messenger.determine_azure_instance_and_tenant_id(my_row) + self.assertEqual(expected_instance, actual_instance) + self.assertEqual(self.azure_tenant, actual_tenant) - @patch("subs.subs_data_messenger.SUBSDataMessenger.determine_azure_instance_id") + @patch("subs.subs_data_messenger.SUBSDataMessenger.determine_azure_instance_and_tenant_id") @patch("subs.subs_data_messenger.os.remove") @patch("subs.subs_data_messenger.get_producer") @patch("subs.subs_data_messenger.csv.DictReader") - @patch("subs.subs_data_messenger.SUBSDataMessenger.build_subs_msg") + @patch("subs.subs_data_messenger.SUBSDataMessenger.build_subs_dict") def test_process_and_send_subs_message_azure_with_id( self, mock_msg_builder, mock_reader, mock_producer, mock_remove, mock_azure_id ): """Tests that the proper functions are called when running process_and_send_subs_message with Azure provider.""" upload_keys = ["fake_key"] self.azure_messenger.date_map = defaultdict(list) + mock_azure_id.return_value = ("string1", "string2") + mock_msg_builder.return_value = defaultdict(str) mock_reader.return_value = [ { "resourceid": "i-55555556", @@ -240,11 +302,11 @@ def test_process_and_send_subs_message_azure_with_id( self.assertEqual(mock_msg_builder.call_count, 4) self.assertEqual(mock_producer.call_count, 4) - @patch("subs.subs_data_messenger.SUBSDataMessenger.determine_azure_instance_id") + @patch("subs.subs_data_messenger.SUBSDataMessenger.determine_azure_instance_and_tenant_id") @patch("subs.subs_data_messenger.os.remove") @patch("subs.subs_data_messenger.get_producer") @patch("subs.subs_data_messenger.csv.DictReader") - @patch("subs.subs_data_messenger.SUBSDataMessenger.build_subs_msg") + @patch("subs.subs_data_messenger.SUBSDataMessenger.build_subs_dict") def test_process_and_send_subs_message_azure_time_already_processed( self, mock_msg_builder, mock_reader, mock_producer, mock_remove, mock_azure_id ):