From ecfbf9186074ed66052f56e3a4c882dbe650f034 Mon Sep 17 00:00:00 2001 From: Cody Myers Date: Mon, 15 Jan 2024 10:53:31 -0500 Subject: [PATCH] [COST-4389] Masu endpoint to convert parquet data types (#4837) * [COST-4389] An internal masu endpoint to fix parquet files. --------- Co-authored-by: Luke Couzens Co-authored-by: Sam Doran --- koku/common/__init__.py | 0 koku/common/enum.py | 36 ++ koku/masu/api/upgrade_trino/__init__.py | 2 + koku/masu/api/upgrade_trino/test/__init__.py | 0 .../test/test_verify_parquet_files.py | 371 ++++++++++++++++++ koku/masu/api/upgrade_trino/test/test_view.py | 65 +++ koku/masu/api/upgrade_trino/util/__init__.py | 0 koku/masu/api/upgrade_trino/util/constants.py | 22 ++ .../api/upgrade_trino/util/state_tracker.py | 140 +++++++ .../api/upgrade_trino/util/task_handler.py | 135 +++++++ .../util/verify_parquet_files.py | 313 +++++++++++++++ koku/masu/api/upgrade_trino/view.py | 37 ++ koku/masu/api/urls.py | 2 + koku/masu/api/views.py | 1 + koku/masu/celery/tasks.py | 8 + 15 files changed, 1132 insertions(+) create mode 100644 koku/common/__init__.py create mode 100644 koku/common/enum.py create mode 100644 koku/masu/api/upgrade_trino/__init__.py create mode 100644 koku/masu/api/upgrade_trino/test/__init__.py create mode 100644 koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py create mode 100644 koku/masu/api/upgrade_trino/test/test_view.py create mode 100644 koku/masu/api/upgrade_trino/util/__init__.py create mode 100644 koku/masu/api/upgrade_trino/util/constants.py create mode 100644 koku/masu/api/upgrade_trino/util/state_tracker.py create mode 100644 koku/masu/api/upgrade_trino/util/task_handler.py create mode 100644 koku/masu/api/upgrade_trino/util/verify_parquet_files.py create mode 100644 koku/masu/api/upgrade_trino/view.py diff --git a/koku/common/__init__.py b/koku/common/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/koku/common/enum.py b/koku/common/enum.py new file mode 100644 index 0000000000..8052fb0c08 --- /dev/null +++ b/koku/common/enum.py @@ -0,0 +1,36 @@ +from enum import Enum + + +class ReprEnum(Enum): + """ + Only changes the repr(), leaving str() and format() to the mixed-in type. + """ + + +# StrEnum is available in python 3.11, vendored over from +# https://github.com/python/cpython/blob/c31be58da8577ef140e83d4e46502c7bb1eb9abf/Lib/enum.py#L1321-L1345 +class StrEnum(str, ReprEnum): # pragma: no cover + """ + Enum where members are also (and must be) strings + """ + + def __new__(cls, *values): + "values must already be of type `str`" + if len(values) > 3: + raise TypeError(f"too many arguments for str(): {values!r}") + if len(values) == 1: + # it must be a string + if not isinstance(values[0], str): + raise TypeError(f"{values[0]!r} is not a string") + if len(values) >= 2: + # check that encoding argument is a string + if not isinstance(values[1], str): + raise TypeError(f"encoding must be a string, not {values[1]!r}") + if len(values) == 3: + # check that errors argument is a string + if not isinstance(values[2], str): + raise TypeError("errors must be a string, not %r" % (values[2])) + value = str(*values) + member = str.__new__(cls, value) + member._value_ = value + return member diff --git a/koku/masu/api/upgrade_trino/__init__.py b/koku/masu/api/upgrade_trino/__init__.py new file mode 100644 index 0000000000..f7b39ea693 --- /dev/null +++ b/koku/masu/api/upgrade_trino/__init__.py @@ -0,0 +1,2 @@ +# Everything in this directory will become +# dead code after the trino upgrade. 
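A minimal sketch (not part of the patch) of the behavior the vendored StrEnum above provides on interpreters older than Python 3.11: members are real str instances that compare equal to their raw values, which is what lets the constants added later in this patch be written into the provider's JSON additional_context as plain strings. The import path matches the constants module introduced below.

    from masu.api.upgrade_trino.util.constants import ConversionStates

    state = ConversionStates.coerce_required
    assert isinstance(state, str)            # str mixin: usable anywhere a plain string is expected
    assert state == "coerce_required"        # compares equal to the raw string value
    assert state.value == "coerce_required"  # .value holds the same data as the member itself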
diff --git a/koku/masu/api/upgrade_trino/test/__init__.py b/koku/masu/api/upgrade_trino/test/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py new file mode 100644 index 0000000000..913b700f8a --- /dev/null +++ b/koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py @@ -0,0 +1,371 @@ +# +# Copyright 2024 Red Hat Inc. +# SPDX-License-Identifier: Apache-2.0 +# +"""Test the verify parquet files endpoint view.""" +import os +import shutil +import tempfile +from collections import namedtuple +from datetime import datetime +from pathlib import Path +from unittest.mock import patch + +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq + +from api.utils import DateHelper +from masu.api.upgrade_trino.util.constants import ConversionContextKeys +from masu.api.upgrade_trino.util.constants import ConversionStates +from masu.api.upgrade_trino.util.constants import CONVERTER_VERSION +from masu.api.upgrade_trino.util.task_handler import FixParquetTaskHandler +from masu.api.upgrade_trino.util.verify_parquet_files import VerifyParquetFiles +from masu.celery.tasks import PROVIDER_REPORT_TYPE_MAP +from masu.config import Config +from masu.test import MasuTestCase +from masu.util.common import get_path_prefix + +DummyS3Object = namedtuple("DummyS3Object", "key") + + +class TestVerifyParquetFiles(MasuTestCase): + def setUp(self): + super().setUp() + # Experienced issues with pyarrow not + # playing nice with tempfiles. Therefore + # I opted for writing files to a tmp dir + self.temp_dir = tempfile.mkdtemp() + self.required_columns = {"float": 0.0, "string": "", "datetime": pd.NaT} + self.expected_pyarrow_dtypes = { + "float": pa.float64(), + "string": pa.string(), + "datetime": pa.timestamp("ms"), + } + self.panda_kwargs = { + "allow_truncated_timestamps": True, + "coerce_timestamps": "ms", + "index": False, + } + self.suffix = ".parquet" + self.bill_date = str(DateHelper().this_month_start) + self.default_provider = self.azure_provider + + def tearDown(self): + shutil.rmtree(self.temp_dir) + + def create_default_verify_handler(self): + return VerifyParquetFiles( + schema_name=self.schema_name, + provider_uuid=str(self.default_provider.uuid), + provider_type=self.default_provider.type, + simulate=True, + bill_date=self.bill_date, + cleaned_column_mapping=self.required_columns, + ) + + def build_expected_additional_context(self, verify_hander, successful=True): + return { + ConversionContextKeys.metadata: { + verify_hander.file_tracker.bill_date_str: { + ConversionContextKeys.version: CONVERTER_VERSION, + ConversionContextKeys.successful: successful, + } + } + } + + def verify_correct_types(self, temp_file, verify_handler): + table = pq.read_table(temp_file) + schema = table.schema + for field in schema: + self.assertEqual(field.type, verify_handler.required_columns.get(field.name)) + + @patch("masu.api.upgrade_trino.util.verify_parquet_files.StateTracker._clean_local_files") + @patch("masu.api.upgrade_trino.util.verify_parquet_files.get_s3_resource") + def test_retrieve_verify_reload_s3_parquet(self, mock_s3_resource, _): + """Test fixes for reindexes on all required columns.""" + # build a parquet file where reindex is used for all required columns + + def create_tmp_test_file(provider, required_columns): + """Creates a parquet file with all empty required columns through reindexing.""" + data_frame = pd.DataFrame() + data_frame = 
data_frame.reindex(columns=required_columns.keys()) + filename = f"test_{str(provider.uuid)}{self.suffix}" + temp_file = os.path.join(self.temp_dir, filename) + data_frame.to_parquet(temp_file, **self.panda_kwargs) + return temp_file + + attributes = ["aws_provider", "azure_provider", "ocp_provider", "oci_provider"] + for attr in attributes: + with self.subTest(attr=attr): + provider = getattr(self, attr) + required_columns = FixParquetTaskHandler.clean_column_names(provider.type) + temp_file = create_tmp_test_file(provider, required_columns) + mock_bucket = mock_s3_resource.return_value.Bucket.return_value + verify_handler = VerifyParquetFiles( + schema_name=self.schema_name, + provider_uuid=str(provider.uuid), + provider_type=provider.type, + simulate=False, + bill_date=self.bill_date, + cleaned_column_mapping=required_columns, + ) + conversion_metadata = provider.additional_context.get(ConversionContextKeys.metadata, {}) + self.assertTrue(verify_handler.file_tracker.add_to_queue(conversion_metadata)) + prefixes = verify_handler._generate_s3_path_prefixes(DateHelper().this_month_start) + filter_side_effect = [[DummyS3Object(key=temp_file)]] + for _ in range(len(prefixes) - 1): + filter_side_effect.append([]) + mock_bucket.objects.filter.side_effect = filter_side_effect + mock_bucket.download_file.return_value = temp_file + VerifyParquetFiles.local_path = Path(self.temp_dir) + verify_handler.retrieve_verify_reload_s3_parquet() + mock_bucket.upload_fileobj.assert_called() + self.verify_correct_types(temp_file, verify_handler) + # Test that the additional context is set correctly + provider.refresh_from_db() + self.assertEqual( + provider.additional_context, self.build_expected_additional_context(verify_handler, True) + ) + conversion_metadata = provider.additional_context.get(ConversionContextKeys.metadata) + self.assertFalse(verify_handler.file_tracker.add_to_queue(conversion_metadata)) + + def test_double_to_timestamp_transformation_with_reindex(self): + """Test double to datetime transformation with values""" + file_data = { + "float": [1.1, 2.2, 3.3], + "string": ["A", "B", "C"], + "unrequired_column": ["a", "b", "c"], + } + test_file = "transformation_test.parquet" + data_frame = pd.DataFrame(file_data) + data_frame = data_frame.reindex(columns=self.required_columns) + temp_file = os.path.join(self.temp_dir, test_file) + data_frame.to_parquet(temp_file, **self.panda_kwargs) + verify_handler = self.create_default_verify_handler() + verify_handler._perform_transformation_double_to_timestamp(temp_file, ["datetime"]) + self.verify_correct_types(temp_file, verify_handler) + + def test_double_to_timestamp_transformation_with_values(self): + """Test double to datetime transformation with values""" + file_data = { + "float": [1.1, 2.2, 3.3], + "string": ["A", "B", "C"], + "datetime": [1.1, 2.2, 3.3], + "unrequired_column": ["a", "b", "c"], + } + test_file = "transformation_test.parquet" + data_frame = pd.DataFrame(file_data) + data_frame = data_frame.reindex(columns=self.required_columns) + temp_file = os.path.join(self.temp_dir, test_file) + data_frame.to_parquet(temp_file, **self.panda_kwargs) + verify_handler = self.create_default_verify_handler() + verify_handler._perform_transformation_double_to_timestamp(temp_file, ["datetime"]) + self.verify_correct_types(temp_file, verify_handler) + + @patch("masu.api.upgrade_trino.util.verify_parquet_files.StateTracker._clean_local_files") + def test_coerce_parquet_data_type_no_changes_needed(self, _): + """Test a parquet file with correct dtypes.""" 
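+ # All columns below already match the expected float/string/timestamp dtypes, so the
+ # coercion check should report no_changes_needed and nothing is flagged for re-upload.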
+ file_data = { + "float": [1.1, 2.2, 3.3], + "string": ["A", "B", "C"], + "datetime": [datetime(2023, 1, 1), datetime(2023, 1, 2), datetime(2023, 1, 3)], + "unrequired_column": ["a", "b", "c"], + } + with tempfile.NamedTemporaryFile(suffix=self.suffix) as temp_file: + pd.DataFrame(file_data).to_parquet(temp_file, **self.panda_kwargs) + verify_handler = self.create_default_verify_handler() + verify_handler.file_tracker.add_local_file(temp_file.name, temp_file) + return_state = verify_handler._coerce_parquet_data_type(temp_file) + verify_handler.file_tracker.set_state(temp_file.name, return_state) + self.assertEqual(return_state, ConversionStates.no_changes_needed) + bill_metadata = verify_handler.file_tracker._create_bill_date_metadata() + self.assertTrue(bill_metadata.get(ConversionContextKeys.successful)) + # Test that generated messages would contain these files. + simulated_messages = verify_handler.file_tracker.generate_simulate_messages() + self.assertIn(str(temp_file.name), simulated_messages.get("Files that have all correct data_types.")) + + def test_coerce_parquet_data_type_coerce_needed(self): + """Test that files created through reindex are fixed correctly.""" + data_frame = pd.DataFrame() + data_frame = data_frame.reindex(columns=self.required_columns.keys()) + filename = f"test{self.suffix}" + temp_file = os.path.join(self.temp_dir, f"test{self.suffix}") + data_frame.to_parquet(temp_file, **self.panda_kwargs) + verify_handler = self.create_default_verify_handler() + verify_handler.file_tracker.add_local_file(filename, Path(temp_file)) + return_state = verify_handler._coerce_parquet_data_type(temp_file) + self.assertEqual(return_state, ConversionStates.coerce_required) + verify_handler.file_tracker.set_state(filename, return_state) + files_need_updating = verify_handler.file_tracker.get_files_that_need_updated() + self.assertTrue(files_need_updating.get(filename)) + self.verify_correct_types(temp_file, verify_handler) + # Test that generated messages would contain these files. + simulated_messages = verify_handler.file_tracker.generate_simulate_messages() + self.assertIn(filename, simulated_messages.get("Files that need to be updated.")) + # Test delete clean local files. 
+ verify_handler.file_tracker._clean_local_files() + self.assertFalse(os.path.exists(temp_file)) + + def test_coerce_parquet_data_type_failed_to_coerce(self): + """Test a parquet file with correct dtypes.""" + file_data = { + "float": [datetime(2023, 1, 1), datetime(2023, 1, 1), datetime(2023, 1, 1)], + "string": ["A", "B", "C"], + "datetime": [datetime(2023, 1, 1), datetime(2023, 1, 2), datetime(2023, 1, 3)], + } + with tempfile.NamedTemporaryFile(suffix=self.suffix) as temp_file: + pd.DataFrame(file_data).to_parquet(temp_file, **self.panda_kwargs) + verify_handler = self.create_default_verify_handler() + verify_handler.file_tracker.add_local_file(temp_file.name, temp_file) + return_state = verify_handler._coerce_parquet_data_type(temp_file) + verify_handler.file_tracker.set_state(temp_file.name, return_state) + self.assertEqual(return_state, ConversionStates.conversion_failed) + verify_handler.file_tracker._check_if_complete() + self.default_provider.refresh_from_db() + conversion_metadata = self.default_provider.additional_context.get(ConversionContextKeys.metadata) + self.assertIsNotNone(conversion_metadata) + bill_metadata = conversion_metadata.get(verify_handler.file_tracker.bill_date_str) + self.assertIsNotNone(bill_metadata) + self.assertFalse(bill_metadata.get(ConversionContextKeys.successful), True) + self.assertIsNotNone(bill_metadata.get(ConversionContextKeys.failed_files)) + # confirm nothing would be sent to s3 + self.assertEqual(verify_handler.file_tracker.get_files_that_need_updated(), {}) + # confirm that it should be retried on next run + self.assertTrue(verify_handler.file_tracker.add_to_queue(conversion_metadata)) + + def test_oci_s3_paths(self): + """test path generation for oci sources.""" + bill_date = DateHelper().this_month_start + expected_s3_paths = [] + for oci_report_type in PROVIDER_REPORT_TYPE_MAP.get(self.oci_provider.type): + path_kwargs = { + "account": self.schema_name, + "provider_type": self.oci_provider.type.replace("-local", ""), + "provider_uuid": self.oci_provider_uuid, + "start_date": bill_date, + "data_type": Config.PARQUET_DATA_TYPE, + "report_type": oci_report_type, + } + expected_s3_paths.append(get_path_prefix(**path_kwargs)) + path_kwargs["daily"] = True + expected_s3_paths.append(get_path_prefix(**path_kwargs)) + verify_handler = VerifyParquetFiles( + schema_name=self.schema_name, + provider_uuid=self.oci_provider_uuid, + provider_type=self.oci_provider.type, + simulate=True, + bill_date=bill_date, + cleaned_column_mapping=self.required_columns, + ) + s3_prefixes = verify_handler._generate_s3_path_prefixes(bill_date) + self.assertEqual(len(s3_prefixes), len(expected_s3_paths)) + for expected_path in expected_s3_paths: + self.assertIn(expected_path, s3_prefixes) + + def test_ocp_s3_paths(self): + """test path generation for ocp sources.""" + bill_date = DateHelper().this_month_start + expected_s3_paths = [] + for ocp_report_type in PROVIDER_REPORT_TYPE_MAP.get(self.ocp_provider.type).keys(): + path_kwargs = { + "account": self.schema_name, + "provider_type": self.ocp_provider.type.replace("-local", ""), + "provider_uuid": self.ocp_provider_uuid, + "start_date": bill_date, + "data_type": Config.PARQUET_DATA_TYPE, + "report_type": ocp_report_type, + } + expected_s3_paths.append(get_path_prefix(**path_kwargs)) + path_kwargs["daily"] = True + expected_s3_paths.append(get_path_prefix(**path_kwargs)) + verify_handler = VerifyParquetFiles( + schema_name=self.schema_name, + provider_uuid=self.ocp_provider_uuid, + provider_type=self.ocp_provider.type, 
+ simulate=True, + bill_date=bill_date, + cleaned_column_mapping=self.required_columns, + ) + s3_prefixes = verify_handler._generate_s3_path_prefixes(bill_date) + self.assertEqual(len(s3_prefixes), len(expected_s3_paths)) + for expected_path in expected_s3_paths: + self.assertIn(expected_path, s3_prefixes) + + def test_other_providers_s3_paths(self): + def _build_expected_s3_paths(metadata): + expected_s3_paths = [] + path_kwargs = { + "account": self.schema_name, + "provider_type": metadata["type"], + "provider_uuid": metadata["uuid"], + "start_date": bill_date, + "data_type": Config.PARQUET_DATA_TYPE, + } + expected_s3_paths.append(get_path_prefix(**path_kwargs)) + path_kwargs["daily"] = True + path_kwargs["report_type"] = "raw" + expected_s3_paths.append(get_path_prefix(**path_kwargs)) + path_kwargs["report_type"] = "openshift" + expected_s3_paths.append(get_path_prefix(**path_kwargs)) + return expected_s3_paths + + bill_date = DateHelper().this_month_start + test_metadata = [ + {"uuid": self.aws_provider_uuid, "type": self.aws_provider.type.replace("-local", "")}, + {"uuid": self.azure_provider_uuid, "type": self.azure_provider.type.replace("-local", "")}, + ] + for metadata in test_metadata: + with self.subTest(metadata=metadata): + expected_s3_paths = _build_expected_s3_paths(metadata) + verify_handler = VerifyParquetFiles( + schema_name=self.schema_name, + provider_uuid=metadata["uuid"], + provider_type=metadata["type"], + simulate=True, + bill_date=bill_date, + cleaned_column_mapping=self.required_columns, + ) + s3_prefixes = verify_handler._generate_s3_path_prefixes(bill_date) + self.assertEqual(len(s3_prefixes), len(expected_s3_paths)) + for expected_path in expected_s3_paths: + self.assertIn(expected_path, s3_prefixes) + + @patch("masu.api.upgrade_trino.util.verify_parquet_files.StateTracker._clean_local_files") + @patch("masu.api.upgrade_trino.util.verify_parquet_files.get_s3_resource") + def test_retrieve_verify_reload_s3_parquet_failure(self, mock_s3_resource, _): + """Test fixes for reindexes on all required columns.""" + # build a parquet file where reindex is used for all required columns + file_data = { + "float": [datetime(2023, 1, 1), datetime(2023, 1, 1), datetime(2023, 1, 1)], + "string": ["A", "B", "C"], + "datetime": [datetime(2023, 1, 1), datetime(2023, 1, 2), datetime(2023, 1, 3)], + } + + bill_date = str(DateHelper().this_month_start) + temp_file = os.path.join(self.temp_dir, f"fail{self.suffix}") + pd.DataFrame(file_data).to_parquet(temp_file, **self.panda_kwargs) + mock_bucket = mock_s3_resource.return_value.Bucket.return_value + verify_handler = VerifyParquetFiles( + schema_name=self.schema_name, + provider_uuid=self.aws_provider_uuid, + provider_type=self.aws_provider.type, + simulate=True, + bill_date=bill_date, + cleaned_column_mapping=self.required_columns, + ) + prefixes = verify_handler._generate_s3_path_prefixes(DateHelper().this_month_start) + filter_side_effect = [[DummyS3Object(key=temp_file)]] + for _ in range(len(prefixes) - 1): + filter_side_effect.append([]) + mock_bucket.objects.filter.side_effect = filter_side_effect + mock_bucket.download_file.return_value = temp_file + VerifyParquetFiles.local_path = Path(self.temp_dir) + verify_handler.retrieve_verify_reload_s3_parquet() + mock_bucket.upload_fileobj.assert_not_called() + os.remove(temp_file) + + def test_local_path(self): + """Test local path.""" + verify_handler = self.create_default_verify_handler() + self.assertTrue(verify_handler.local_path) diff --git 
a/koku/masu/api/upgrade_trino/test/test_view.py b/koku/masu/api/upgrade_trino/test/test_view.py new file mode 100644 index 0000000000..bbf169921c --- /dev/null +++ b/koku/masu/api/upgrade_trino/test/test_view.py @@ -0,0 +1,65 @@ +# +# Copyright 2024 Red Hat Inc. +# SPDX-License-Identifier: Apache-2.0 +# +"""Test the verify parquet files endpoint view.""" +import datetime +from unittest.mock import patch +from uuid import uuid4 + +from django.test.utils import override_settings +from django.urls import reverse + +from api.models import Provider +from masu.api.upgrade_trino.util.task_handler import FixParquetTaskHandler +from masu.processor.tasks import GET_REPORT_FILES_QUEUE +from masu.test import MasuTestCase + + +@override_settings(ROOT_URLCONF="masu.urls") +class TestUpgradeTrinoView(MasuTestCase): + ENDPOINT = "fix_parquet" + bill_date = datetime.datetime(2024, 1, 1, 0, 0) + + @patch("koku.middleware.MASU", return_value=True) + def test_required_parameters_failure(self, _): + """Test the hcs_report_finalization endpoint.""" + parameter_options = [{}, {"start_date": self.bill_date}, {"provider_uuid": self.aws_provider_uuid}] + for parameters in parameter_options: + with self.subTest(parameters=parameters): + response = self.client.get(reverse(self.ENDPOINT), parameters) + self.assertEqual(response.status_code, 400) + + @patch("koku.middleware.MASU", return_value=True) + def test_provider_uuid_does_not_exist(self, _): + """Test the hcs_report_finalization endpoint.""" + parameters = {"start_date": self.bill_date, "provider_uuid": str(uuid4())} + response = self.client.get(reverse(self.ENDPOINT), parameters) + self.assertEqual(response.status_code, 400) + + @patch("koku.middleware.MASU", return_value=True) + def test_acceptable_parameters(self, _): + """Test that the endpoint accepts""" + acceptable_parameters = [ + {"start_date": self.bill_date, "provider_uuid": self.aws_provider_uuid, "simulate": True}, + {"start_date": self.bill_date, "provider_uuid": self.aws_provider_uuid, "simulate": "bad_value"}, + {"start_date": self.bill_date, "provider_type": self.aws_provider.type}, + ] + cleaned_column_mapping = FixParquetTaskHandler.clean_column_names(self.aws_provider.type) + for parameters in acceptable_parameters: + with self.subTest(parameters=parameters): + with patch("masu.celery.tasks.fix_parquet_data_types.apply_async") as patch_celery: + response = self.client.get(reverse(self.ENDPOINT), parameters) + self.assertEqual(response.status_code, 200) + simulate = parameters.get("simulate", False) + if simulate == "bad_value": + simulate = False + async_kwargs = { + "schema_name": self.schema_name, + "provider_type": Provider.PROVIDER_AWS_LOCAL, + "provider_uuid": self.aws_provider.uuid, + "simulate": simulate, + "bill_date": self.bill_date, + "cleaned_column_mapping": cleaned_column_mapping, + } + patch_celery.assert_called_once_with((), async_kwargs, queue=GET_REPORT_FILES_QUEUE) diff --git a/koku/masu/api/upgrade_trino/util/__init__.py b/koku/masu/api/upgrade_trino/util/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/koku/masu/api/upgrade_trino/util/constants.py b/koku/masu/api/upgrade_trino/util/constants.py new file mode 100644 index 0000000000..fff80b596c --- /dev/null +++ b/koku/masu/api/upgrade_trino/util/constants.py @@ -0,0 +1,22 @@ +from common.enum import StrEnum + +# Update this to trigger the converter to run again +# even if marked as successful +CONVERTER_VERSION = "1" + + +class ConversionContextKeys(StrEnum): + metadata = "conversion_metadata" + 
version = "version" + successful = "successful" + failed_files = "dtype_failed_files" + + +class ConversionStates(StrEnum): + found_s3_file = "found_s3_file" + downloaded_locally = "downloaded_locally" + no_changes_needed = "no_changes_needed" + coerce_required = "coerce_required" + s3_complete = "sent_to_s3_complete" + s3_failed = "sent_to_s3_failed" + conversion_failed = "failed_data_type_conversion" diff --git a/koku/masu/api/upgrade_trino/util/state_tracker.py b/koku/masu/api/upgrade_trino/util/state_tracker.py new file mode 100644 index 0000000000..0f4288f1f1 --- /dev/null +++ b/koku/masu/api/upgrade_trino/util/state_tracker.py @@ -0,0 +1,140 @@ +import logging +from datetime import date + +from api.common import log_json +from api.provider.provider_manager import ProviderManager +from api.provider.provider_manager import ProviderManagerError +from masu.api.upgrade_trino.util.constants import ConversionContextKeys +from masu.api.upgrade_trino.util.constants import ConversionStates +from masu.api.upgrade_trino.util.constants import CONVERTER_VERSION + + +LOG = logging.getLogger(__name__) + + +class StateTracker: + """Tracks the state of each s3 file for the provider per bill date""" + + def __init__(self, provider_uuid: str, bill_date: date): + self.files = [] + self.tracker = {} + self.local_files = {} + self.provider_uuid = provider_uuid + self.bill_date_str = bill_date.strftime("%Y-%m-%d") + + def add_to_queue(self, conversion_metadata): + """ + Checks the provider object's metadata to see if we should start the task. + + Args: + conversion_metadata (dict): Metadata for the conversion. + + Returns: + bool: True if the task should be added to the queue, False otherwise. + """ + bill_metadata = conversion_metadata.get(self.bill_date_str, {}) + if bill_metadata.get(ConversionContextKeys.version) != CONVERTER_VERSION: + # always kick off a task if the version does not match or exist. + return True + if bill_metadata.get(ConversionContextKeys.successful): + # if the conversion was successful for this version do not kick + # off a task. + LOG.info( + log_json( + self.provider_uuid, + msg="Conversion already marked as successful", + bill_date=self.bill_date_str, + provider_uuid=self.provider_uuid, + ) + ) + return False + return True + + def set_state(self, s3_obj_key, state): + self.tracker[s3_obj_key] = state + + def add_local_file(self, s3_obj_key, local_path): + self.local_files[s3_obj_key] = local_path + self.tracker[s3_obj_key] = ConversionStates.downloaded_locally + + def get_files_that_need_updated(self): + """Returns a mapping of files in the s3 needs + updating state. + + {s3_object_key: local_file_path} for + """ + return { + s3_obj_key: self.local_files.get(s3_obj_key) + for s3_obj_key, state in self.tracker.items() + if state == ConversionStates.coerce_required + } + + def generate_simulate_messages(self): + """ + Generates the simulate messages. 
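+ Returns a dict that maps a summary message to the list of s3 object keys in that state:
+ files with all correct data types, files that need to be updated, and files that failed
+ to convert. Counts are logged per state and local copies are removed before returning.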
+ """ + + files_count = 0 + files_failed = [] + files_need_updated = [] + files_correct = [] + for s3_obj_key, state in self.tracker.items(): + files_count += 1 + if state == ConversionStates.coerce_required: + files_need_updated.append(s3_obj_key) + elif state == ConversionStates.no_changes_needed: + files_correct.append(s3_obj_key) + else: + files_failed.append(s3_obj_key) + simulate_info = { + "Files that have all correct data_types.": files_correct, + "Files that need to be updated.": files_need_updated, + "Files that failed to convert.": files_failed, + } + for substring, files_list in simulate_info.items(): + LOG.info( + log_json( + self.provider_uuid, + msg=substring, + file_count=len(files_list), + total_count=files_count, + bill_date=self.bill_date_str, + ) + ) + self._clean_local_files() + return simulate_info + + def _clean_local_files(self): + for file_path in self.local_files.values(): + file_path.unlink(missing_ok=True) + + def _create_bill_date_metadata(self): + # Check for incomplete files + bill_date_data = {"version": CONVERTER_VERSION} + incomplete_files = [] + for file_prefix, state in self.tracker.items(): + if state not in [ConversionStates.s3_complete, ConversionStates.no_changes_needed]: + file_metadata = {"key": file_prefix, "state": state} + incomplete_files.append(file_metadata) + if incomplete_files: + bill_date_data[ConversionContextKeys.successful] = False + bill_date_data[ConversionContextKeys.failed_files] = incomplete_files + if not incomplete_files: + bill_date_data[ConversionContextKeys.successful] = True + return bill_date_data + + def _check_if_complete(self): + try: + manager = ProviderManager(self.provider_uuid) + context = manager.get_additional_context() + conversion_metadata = context.get(ConversionContextKeys.metadata, {}) + conversion_metadata[self.bill_date_str] = self._create_bill_date_metadata() + context[ConversionContextKeys.metadata] = conversion_metadata + manager.model.set_additional_context(context) + LOG.info(self.provider_uuid, log_json(msg="setting dtype states", context=context)) + except ProviderManagerError: + pass + + def finalize_and_clean_up(self): + self._check_if_complete() + self._clean_local_files() diff --git a/koku/masu/api/upgrade_trino/util/task_handler.py b/koku/masu/api/upgrade_trino/util/task_handler.py new file mode 100644 index 0000000000..ed2b5158d1 --- /dev/null +++ b/koku/masu/api/upgrade_trino/util/task_handler.py @@ -0,0 +1,135 @@ +import copy +import logging +from dataclasses import dataclass +from dataclasses import field +from typing import Optional + +from dateutil import parser +from django.http import QueryDict + +from api.common import log_json +from api.provider.models import Provider +from api.utils import DateHelper +from masu.api.upgrade_trino.util.constants import ConversionContextKeys +from masu.api.upgrade_trino.util.state_tracker import StateTracker +from masu.celery.tasks import fix_parquet_data_types +from masu.processor import is_customer_large +from masu.processor.tasks import GET_REPORT_FILES_QUEUE +from masu.processor.tasks import GET_REPORT_FILES_QUEUE_XL +from masu.util.common import strip_characters_from_column_name +from reporting.provider.aws.models import TRINO_REQUIRED_COLUMNS as AWS_TRINO_REQUIRED_COLUMNS +from reporting.provider.azure.models import TRINO_REQUIRED_COLUMNS as AZURE_TRINO_REQUIRED_COLUMNS +from reporting.provider.oci.models import TRINO_REQUIRED_COLUMNS as OCI_TRINO_REQUIRED_COLUMNS + +LOG = logging.getLogger(__name__) + + +class RequiredParametersError(Exception): + 
"""Handle require parameters error.""" + + +@dataclass(frozen=True) +class FixParquetTaskHandler: + start_date: Optional[str] = field(default=None) + provider_uuid: Optional[str] = field(default=None) + provider_type: Optional[str] = field(default=None) + simulate: Optional[bool] = field(default=False) + cleaned_column_mapping: Optional[dict] = field(default=None) + + # Node role is the only column we add manually for OCP + # Therefore, it is the only column that can be incorrect + REQUIRED_COLUMNS_MAPPING = { + Provider.PROVIDER_OCI: OCI_TRINO_REQUIRED_COLUMNS, + Provider.PROVIDER_OCP: {"node_role": ""}, + Provider.PROVIDER_AWS: AWS_TRINO_REQUIRED_COLUMNS, + Provider.PROVIDER_AZURE: AZURE_TRINO_REQUIRED_COLUMNS, + } + + @classmethod + def from_query_params(cls, query_params: QueryDict) -> "FixParquetTaskHandler": + """Create an instance from query parameters.""" + kwargs = {} + if start_date := query_params.get("start_date"): + if start_date: + kwargs["start_date"] = parser.parse(start_date).replace(day=1) + + if provider_uuid := query_params.get("provider_uuid"): + provider = Provider.objects.filter(uuid=provider_uuid).first() + if not provider: + raise RequiredParametersError(f"The provider_uuid {provider_uuid} does not exist.") + kwargs["provider_uuid"] = provider_uuid + kwargs["provider_type"] = provider.type + + if provider_type := query_params.get("provider_type"): + kwargs["provider_type"] = provider_type + + if simulate := query_params.get("simulate"): + if simulate.lower() == "true": + kwargs["simulate"] = True + + if not kwargs.get("provider_type") and not kwargs.get("provider_uuid"): + raise RequiredParametersError("provider_uuid or provider_type must be supplied") + + if not kwargs.get("start_date"): + raise RequiredParametersError("start_date must be supplied as a parameter.") + + kwargs["cleaned_column_mapping"] = cls.clean_column_names(kwargs["provider_type"]) + return cls(**kwargs) + + @classmethod + def clean_column_names(cls, provider_type): + """Creates a mapping of columns to expected pyarrow values.""" + clean_column_names = {} + provider_mapping = cls.REQUIRED_COLUMNS_MAPPING.get(provider_type.replace("-local", "")) + # Our required mapping stores the raw column name; however, + # the parquet files will contain the cleaned column name. + for raw_col, default_val in provider_mapping.items(): + clean_column_names[strip_characters_from_column_name(raw_col)] = default_val + return clean_column_names + + def build_celery_tasks(self): + """ + Fixes the parquet file data type for each account. + Args: + simulate (Boolean) simulate the parquet file fixing. + Returns: + (celery.result.AsyncResult) Async result for deletion request. 
+ """ + async_results = [] + if self.provider_uuid: + providers = Provider.objects.filter(uuid=self.provider_uuid) + else: + providers = Provider.objects.filter(active=True, paused=False, type=self.provider_type) + + for provider in providers: + queue_name = GET_REPORT_FILES_QUEUE + if is_customer_large(provider.account["schema_name"]): + queue_name = GET_REPORT_FILES_QUEUE_XL + + account = copy.deepcopy(provider.account) + conversion_metadata = provider.additional_context.get(ConversionContextKeys.metadata, {}) + dh = DateHelper() + bill_datetimes = dh.list_months(self.start_date, dh.today.replace(tzinfo=None)) + for bill_date in bill_datetimes: + tracker = StateTracker(self.provider_uuid, bill_date) + if tracker.add_to_queue(conversion_metadata): + async_result = fix_parquet_data_types.s( + schema_name=account.get("schema_name"), + provider_type=account.get("provider_type"), + provider_uuid=account.get("provider_uuid"), + simulate=self.simulate, + bill_date=bill_date, + cleaned_column_mapping=self.cleaned_column_mapping, + ).apply_async(queue=queue_name) + LOG.info( + log_json( + provider.uuid, + msg="Calling fix_parquet_data_types", + schema=account.get("schema_name"), + provider_uuid=provider.uuid, + task_id=str(async_result), + bill_date=bill_date, + ) + ) + async_results.append(str(async_result)) + return async_results diff --git a/koku/masu/api/upgrade_trino/util/verify_parquet_files.py b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py new file mode 100644 index 0000000000..71575642a7 --- /dev/null +++ b/koku/masu/api/upgrade_trino/util/verify_parquet_files.py @@ -0,0 +1,313 @@ +import logging +import os +import uuid +from pathlib import Path + +import ciso8601 +import pyarrow as pa +import pyarrow.parquet as pq +from botocore.exceptions import ClientError +from django.conf import settings +from django_tenants.utils import schema_context + +from api.common import log_json +from api.provider.models import Provider +from masu.api.upgrade_trino.util.constants import ConversionStates +from masu.api.upgrade_trino.util.constants import CONVERTER_VERSION +from masu.api.upgrade_trino.util.state_tracker import StateTracker +from masu.config import Config +from masu.processor.parquet.parquet_report_processor import OPENSHIFT_REPORT_TYPE +from masu.util.aws.common import get_s3_resource +from masu.util.common import get_path_prefix + + +LOG = logging.getLogger(__name__) + + +class VerifyParquetFiles: + S3_OBJ_LOG_KEY = "s3_object_key" + S3_PREFIX_LOG_KEY = "s3_prefix" + + def __init__(self, schema_name, provider_uuid, provider_type, simulate, bill_date, cleaned_column_mapping): + self.schema_name = schema_name + self.provider_uuid = uuid.UUID(provider_uuid) + self.provider_type = provider_type.replace("-local", "") + self.simulate = simulate + self.bill_date = self._bill_date(bill_date) + self.file_tracker = StateTracker(provider_uuid, self.bill_date) + self.report_types = self._set_report_types() + self.required_columns = self._set_pyarrow_types(cleaned_column_mapping) + self.logging_context = { + "provider_type": self.provider_type, + "provider_uuid": self.provider_uuid, + "schema": self.schema_name, + "simulate": self.simulate, + "bill_date": self.bill_date, + } + + def _bill_date(self, bill_date): + """bill_date""" + if isinstance(bill_date, str): + return ciso8601.parse_datetime(bill_date).replace(tzinfo=None).date() + return bill_date + + def _set_pyarrow_types(self, cleaned_column_mapping): + mapping = {} + for key, default_val in cleaned_column_mapping.items(): + if 
str(default_val) == "NaT": + # Store original provider datetime type + if self.provider_type == "Azure": + mapping[key] = pa.timestamp("ms") + else: + mapping[key] = pa.timestamp("ms", tz="UTC") + elif isinstance(default_val, str): + mapping[key] = pa.string() + elif isinstance(default_val, float): + mapping[key] = pa.float64() + return mapping + + def _set_report_types(self): + if self.provider_type == Provider.PROVIDER_OCI: + return ["cost", "usage"] + if self.provider_type == Provider.PROVIDER_OCP: + return ["namespace_labels", "node_labels", "pod_usage", "storage_usage"] + return [None] + + def _parquet_path_s3(self, bill_date, report_type): + """The path in the S3 bucket where Parquet files are loaded.""" + return get_path_prefix( + self.schema_name, + self.provider_type, + self.provider_uuid, + bill_date, + Config.PARQUET_DATA_TYPE, + report_type=report_type, + ) + + def _parquet_daily_path_s3(self, bill_date, report_type): + """The path in the S3 bucket where Parquet files are loaded.""" + if report_type is None: + report_type = "raw" + return get_path_prefix( + self.schema_name, + self.provider_type, + self.provider_uuid, + bill_date, + Config.PARQUET_DATA_TYPE, + report_type=report_type, + daily=True, + ) + + def _parquet_ocp_on_cloud_path_s3(self, bill_date): + """The path in the S3 bucket where Parquet files are loaded.""" + return get_path_prefix( + self.schema_name, + self.provider_type, + self.provider_uuid, + bill_date, + Config.PARQUET_DATA_TYPE, + report_type=OPENSHIFT_REPORT_TYPE, + daily=True, + ) + + def _generate_s3_path_prefixes(self, bill_date): + """ + generates the s3 path prefixes. + """ + with schema_context(self.schema_name): + ocp_on_cloud_check = Provider.objects.filter( + infrastructure__infrastructure_provider__uuid=self.provider_uuid + ).exists() + path_prefixes = set() + for report_type in self.report_types: + path_prefixes.add(self._parquet_path_s3(bill_date, report_type)) + path_prefixes.add(self._parquet_daily_path_s3(bill_date, report_type)) + if ocp_on_cloud_check: + path_prefixes.add(self._parquet_ocp_on_cloud_path_s3(bill_date)) + return path_prefixes + + @property + def local_path(self): + local_path = Path(Config.TMP_DIR, self.schema_name, str(self.provider_uuid)) + local_path.mkdir(parents=True, exist_ok=True) + return local_path + + def retrieve_verify_reload_s3_parquet(self): + """Retrieves the s3 files from s3""" + s3_resource = get_s3_resource(settings.S3_ACCESS_KEY, settings.S3_SECRET, settings.S3_REGION) + s3_bucket = s3_resource.Bucket(settings.S3_BUCKET_NAME) + for prefix in self._generate_s3_path_prefixes(self.bill_date): + self.logging_context[self.S3_PREFIX_LOG_KEY] = prefix + LOG.info( + log_json( + self.provider_uuid, + msg="Retrieving files from S3.", + context=self.logging_context, + prefix=prefix, + ) + ) + for s3_object in s3_bucket.objects.filter(Prefix=prefix): + s3_object_key = s3_object.key + self.logging_context[self.S3_OBJ_LOG_KEY] = s3_object_key + self.file_tracker.set_state(s3_object_key, ConversionStates.found_s3_file) + local_file_path = self.local_path.joinpath(os.path.basename(s3_object_key)) + LOG.info( + log_json( + self.provider_uuid, + msg="Downloading file locally", + context=self.logging_context, + ) + ) + s3_bucket.download_file(s3_object_key, local_file_path) + self.file_tracker.add_local_file(s3_object_key, local_file_path) + self.file_tracker.set_state(s3_object_key, self._coerce_parquet_data_type(local_file_path)) + del self.logging_context[self.S3_OBJ_LOG_KEY] + del 
self.logging_context[self.S3_PREFIX_LOG_KEY] + + if self.simulate: + self.file_tracker.generate_simulate_messages() + return False + else: + files_need_updated = self.file_tracker.get_files_that_need_updated() + for s3_obj_key, converted_local_file_path in files_need_updated.items(): + self.logging_context[self.S3_OBJ_LOG_KEY] = s3_obj_key + # Overwrite s3 object with updated file data + with open(converted_local_file_path, "rb") as new_file: + LOG.info( + log_json( + self.provider_uuid, + msg="Uploading revised parquet file.", + context=self.logging_context, + local_file_path=converted_local_file_path, + ) + ) + try: + s3_bucket.upload_fileobj( + new_file, + s3_obj_key, + ExtraArgs={"Metadata": {"converter_version": CONVERTER_VERSION}}, + ) + self.file_tracker.set_state(s3_obj_key, ConversionStates.s3_complete) + except ClientError as e: + LOG.info(f"Failed to overwrite S3 file {s3_object_key}: {str(e)}") + self.file_tracker.set_state(s3_object_key, ConversionStates.s3_failed) + continue + self.file_tracker.finalize_and_clean_up() + + def _perform_transformation_double_to_timestamp(self, parquet_file_path, field_names): + """Performs a transformation to change a double to a timestamp.""" + if not field_names: + return + LOG.info( + log_json( + self.provider_uuid, + msg="Transforming fields from double to timestamp.", + context=self.logging_context, + local_file_path=parquet_file_path, + updated_columns=field_names, + ) + ) + table = pq.read_table(parquet_file_path) + schema = table.schema + fields = [] + for field in schema: + if field.name in field_names: + replacement_value = [] + correct_data_type = self.required_columns.get(field.name) + corrected_column = pa.array(replacement_value, type=correct_data_type) + field = pa.field(field.name, corrected_column.type) + fields.append(field) + # Create a new schema + new_schema = pa.schema(fields) + # Create a DataFrame from the original PyArrow Table + original_df = table.to_pandas() + + # Update the DataFrame with corrected values + for field_name in field_names: + if field_name in original_df.columns: + original_df[field_name] = corrected_column.to_pandas() + + # Create a new PyArrow Table from the updated DataFrame + new_table = pa.Table.from_pandas(original_df, schema=new_schema) + + # Write the new table back to the Parquet file + pq.write_table(new_table, parquet_file_path) + + # Same logic as last time, but combined into one method & added state tracking + def _coerce_parquet_data_type(self, parquet_file_path): + """If a parquet file has an incorrect dtype we can attempt to coerce + it to the correct type it. + + Returns a boolean indicating if the update parquet file should be sent + to s3. 
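+ (Concretely, a ConversionStates member is returned: no_changes_needed, coerce_required,
+ or conversion_failed; only files marked coerce_required are uploaded back to s3.)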
+ """ + LOG.info( + log_json( + self.provider_uuid, + msg="Checking local parquet_file", + context=self.logging_context, + local_file_path=parquet_file_path, + ) + ) + corrected_fields = {} + double_to_timestamp_fields = [] + try: + table = pq.read_table(parquet_file_path) + schema = table.schema + fields = [] + for field in schema: + if correct_data_type := self.required_columns.get(field.name): + # Check if the field's type matches the desired type + if field.type != correct_data_type: + LOG.info( + log_json( + self.provider_uuid, + msg="Incorrect data type, building new schema.", + context=self.logging_context, + column_name=field.name, + current_dtype=field.type, + expected_data_type=correct_data_type, + ) + ) + if field.type == pa.float64() and correct_data_type in [ + pa.timestamp("ms"), + pa.timestamp("ms", tz="UTC"), + ]: + double_to_timestamp_fields.append(field.name) + else: + field = pa.field(field.name, correct_data_type) + corrected_fields[field.name] = correct_data_type + fields.append(field) + + if not corrected_fields and not double_to_timestamp_fields: + # Final State: No changes needed. + LOG.info( + log_json( + self.provider_uuid, + msg="All data types correct", + context=self.logging_context, + local_file_path=parquet_file_path, + ) + ) + return ConversionStates.no_changes_needed + + new_schema = pa.schema(fields) + LOG.info( + log_json( + self.provider_uuid, + msg="Applying new parquet schema to local parquet file.", + context=self.logging_context, + local_file_path=parquet_file_path, + updated_columns=corrected_fields, + ) + ) + table = table.cast(new_schema) + # Write the table back to the Parquet file + pa.parquet.write_table(table, parquet_file_path) + self._perform_transformation_double_to_timestamp(parquet_file_path, double_to_timestamp_fields) + # Signal that we need to send this update to S3. + return ConversionStates.coerce_required + + except Exception as e: + LOG.info(log_json(self.provider_uuid, msg="Failed to coerce data.", context=self.logging_context, error=e)) + return ConversionStates.conversion_failed diff --git a/koku/masu/api/upgrade_trino/view.py b/koku/masu/api/upgrade_trino/view.py new file mode 100644 index 0000000000..ae1524abc7 --- /dev/null +++ b/koku/masu/api/upgrade_trino/view.py @@ -0,0 +1,37 @@ +# +# Copyright 2024 Red Hat Inc. 
+# SPDX-License-Identifier: Apache-2.0 +# +"""View for fixing parquet files endpoint.""" +import logging + +from django.views.decorators.cache import never_cache +from rest_framework import status +from rest_framework.decorators import api_view +from rest_framework.decorators import permission_classes +from rest_framework.decorators import renderer_classes +from rest_framework.permissions import AllowAny +from rest_framework.response import Response +from rest_framework.settings import api_settings + +from masu.api.upgrade_trino.util.task_handler import FixParquetTaskHandler +from masu.api.upgrade_trino.util.task_handler import RequiredParametersError + +LOG = logging.getLogger(__name__) + + +@never_cache +@api_view(http_method_names=["GET", "DELETE"]) +@permission_classes((AllowAny,)) +@renderer_classes(tuple(api_settings.DEFAULT_RENDERER_CLASSES)) +def fix_parquet(request): + """Fix parquet files so that we can upgrade Trino.""" + try: + task_handler = FixParquetTaskHandler.from_query_params(request.query_params) + async_fix_results = task_handler.build_celery_tasks() + except RequiredParametersError as errmsg: + return Response({"Error": str(errmsg)}, status=status.HTTP_400_BAD_REQUEST) + response_key = "Async jobs for fix parquet files" + if task_handler.simulate: + response_key = response_key + " (simulated)" + return Response({response_key: str(async_fix_results)}) diff --git a/koku/masu/api/urls.py b/koku/masu/api/urls.py index 54513f4575..e9d7ec1510 100644 --- a/koku/masu/api/urls.py +++ b/koku/masu/api/urls.py @@ -23,6 +23,7 @@ from masu.api.views import EnabledTagView from masu.api.views import expired_data from masu.api.views import explain_query +from masu.api.views import fix_parquet from masu.api.views import get_status from masu.api.views import hcs_report_data from masu.api.views import hcs_report_finalization @@ -48,6 +49,7 @@ urlpatterns = [ + path("fix_parquet/", fix_parquet, name="fix_parquet"), path("status/", get_status, name="server-status"), path("download/", download_report, name="report_download"), path("ingress_reports/", ingress_reports, name="ingress_reports"), diff --git a/koku/masu/api/views.py b/koku/masu/api/views.py index 124c26e00f..aa97e2b284 100644 --- a/koku/masu/api/views.py +++ b/koku/masu/api/views.py @@ -39,3 +39,4 @@ from masu.api.update_cost_model_costs import update_cost_model_costs from masu.api.update_exchange_rates import update_exchange_rates from masu.api.update_openshift_on_cloud import update_openshift_on_cloud +from masu.api.upgrade_trino.view import fix_parquet diff --git a/koku/masu/celery/tasks.py b/koku/masu/celery/tasks.py index 3e0159c4d2..38ae083f4a 100644 --- a/koku/masu/celery/tasks.py +++ b/koku/masu/celery/tasks.py @@ -30,6 +30,7 @@ from api.utils import DateHelper from koku import celery_app from koku.notifications import NotificationService +from masu.api.upgrade_trino.util.verify_parquet_files import VerifyParquetFiles from masu.config import Config from masu.database.cost_model_db_accessor import CostModelDBAccessor from masu.database.ocp_report_db_accessor import OCPReportDBAccessor @@ -39,6 +40,7 @@ from masu.processor.orchestrator import Orchestrator from masu.processor.tasks import autovacuum_tune_schema from masu.processor.tasks import DEFAULT +from masu.processor.tasks import GET_REPORT_FILES_QUEUE from masu.processor.tasks import PRIORITY_QUEUE from masu.processor.tasks import REMOVE_EXPIRED_DATA_QUEUE from masu.prometheus_stats import QUEUES @@ -57,6 +59,12 @@ } 
+@celery_app.task(name="masu.celery.tasks.fix_parquet_data_types", queue=GET_REPORT_FILES_QUEUE)
+def fix_parquet_data_types(*args, **kwargs):
+    verify_parquet = VerifyParquetFiles(*args, **kwargs)
+    verify_parquet.retrieve_verify_reload_s3_parquet()
+
+
 @celery_app.task(name="masu.celery.tasks.check_report_updates", queue=DEFAULT)
 def check_report_updates(*args, **kwargs):
     """Scheduled task to initiate scanning process on a regular interval."""
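
A minimal sketch (not part of the patch) of the coercion step implemented in VerifyParquetFiles._coerce_parquet_data_type: read a parquet file, compare each column's pyarrow type against an expected mapping, and cast the table when they differ. The file path and column names here are hypothetical.

    import pyarrow as pa
    import pyarrow.parquet as pq

    # Hypothetical expected mapping; the real one comes from FixParquetTaskHandler.clean_column_names.
    expected = {"usage_quantity": pa.float64(), "resource_id": pa.string()}

    table = pq.read_table("example.parquet")
    fields = []
    for field in table.schema:
        wanted = expected.get(field.name)
        if wanted is not None and field.type != wanted:
            field = pa.field(field.name, wanted)  # swap in the corrected field type
        fields.append(field)

    new_schema = pa.schema(fields)
    if not new_schema.equals(table.schema):
        # cast() coerces column data where a safe conversion exists (e.g. int64 -> double)
        pq.write_table(table.cast(new_schema), "example.parquet")

The double-to-timestamp case is handled separately in the patch (_perform_transformation_double_to_timestamp) because that conversion is not a plain schema cast.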