Skip to content

Commit

Permalink
[COST-4389] Masu endpoint to convert parquet data types (#4837)
Browse files Browse the repository at this point in the history
* [COST-4389] An internal masu endpoint to fix parquet files.

---------

Co-authored-by: Luke Couzens <[email protected]>
Co-authored-by: Sam Doran <[email protected]>
  • Loading branch information
3 people authored Jan 15, 2024
1 parent 85e6f9d commit ecfbf91
Show file tree
Hide file tree
Showing 15 changed files with 1,132 additions and 0 deletions.
Empty file added koku/common/__init__.py
Empty file.
36 changes: 36 additions & 0 deletions koku/common/enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from enum import Enum


class ReprEnum(Enum):
    """
    Only changes the repr(), leaving str() and format() to the mixed-in type.
    """


# StrEnum is available in python 3.11, vendored over from
# https://github.com/python/cpython/blob/c31be58da8577ef140e83d4e46502c7bb1eb9abf/Lib/enum.py#L1321-L1345
class StrEnum(str, ReprEnum):  # pragma: no cover
    """
    Enum where members are also (and must be) strings
    """

    def __new__(cls, *values):
        "values must already be of type `str`"
        if len(values) > 3:
            raise TypeError(f"too many arguments for str(): {values!r}")
        if len(values) == 1:
            # it must be a string
            if not isinstance(values[0], str):
                raise TypeError(f"{values[0]!r} is not a string")
        if len(values) >= 2:
            # check that encoding argument is a string
            if not isinstance(values[1], str):
                raise TypeError(f"encoding must be a string, not {values[1]!r}")
        if len(values) == 3:
            # check that errors argument is a string
            if not isinstance(values[2], str):
                raise TypeError("errors must be a string, not %r" % (values[2]))
        value = str(*values)
        member = str.__new__(cls, value)
        member._value_ = value
        return member

    # The real 3.11 ReprEnum works because the stdlib's EnumType special-cases
    # enum.ReprEnum — it does not recognize this vendored copy, so without the
    # assignments below str(member)/format(member) would fall back to Enum's
    # "ClassName.member" rendering on every Python version. Pin them to the
    # plain string value, matching the documented ReprEnum/StrEnum contract.
    __str__ = str.__str__
    __format__ = str.__format__

    @staticmethod
    def _generate_next_value_(name, start, count, last_values):
        """Return the lower-cased member name for auto(), as 3.11 StrEnum does."""
        return name.lower()
2 changes: 2 additions & 0 deletions koku/masu/api/upgrade_trino/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Everything in this directory will become
# dead code after the trino upgrade.
Empty file.
371 changes: 371 additions & 0 deletions koku/masu/api/upgrade_trino/test/test_verify_parquet_files.py

Large diffs are not rendered by default.

65 changes: 65 additions & 0 deletions koku/masu/api/upgrade_trino/test/test_view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#
# Copyright 2024 Red Hat Inc.
# SPDX-License-Identifier: Apache-2.0
#
"""Test the verify parquet files endpoint view."""
import datetime
from unittest.mock import patch
from uuid import uuid4

from django.test.utils import override_settings
from django.urls import reverse

from api.models import Provider
from masu.api.upgrade_trino.util.task_handler import FixParquetTaskHandler
from masu.processor.tasks import GET_REPORT_FILES_QUEUE
from masu.test import MasuTestCase


@override_settings(ROOT_URLCONF="masu.urls")
class TestUpgradeTrinoView(MasuTestCase):
    """Tests for the masu fix_parquet endpoint."""

    ENDPOINT = "fix_parquet"
    bill_date = datetime.datetime(2024, 1, 1, 0, 0)

    @patch("koku.middleware.MASU", return_value=True)
    def test_required_parameters_failure(self, _):
        """Requests missing a required parameter are rejected with a 400."""
        incomplete_options = [{}, {"start_date": self.bill_date}, {"provider_uuid": self.aws_provider_uuid}]
        for params in incomplete_options:
            with self.subTest(parameters=params):
                resp = self.client.get(reverse(self.ENDPOINT), params)
                self.assertEqual(resp.status_code, 400)

    @patch("koku.middleware.MASU", return_value=True)
    def test_provider_uuid_does_not_exist(self, _):
        """An unknown provider_uuid is rejected with a 400."""
        params = {"start_date": self.bill_date, "provider_uuid": str(uuid4())}
        resp = self.client.get(reverse(self.ENDPOINT), params)
        self.assertEqual(resp.status_code, 400)

    @patch("koku.middleware.MASU", return_value=True)
    def test_acceptable_parameters(self, _):
        """Valid parameter combinations return 200 and enqueue the conversion task."""
        acceptable_parameters = [
            {"start_date": self.bill_date, "provider_uuid": self.aws_provider_uuid, "simulate": True},
            {"start_date": self.bill_date, "provider_uuid": self.aws_provider_uuid, "simulate": "bad_value"},
            {"start_date": self.bill_date, "provider_type": self.aws_provider.type},
        ]
        cleaned_column_mapping = FixParquetTaskHandler.clean_column_names(self.aws_provider.type)
        for params in acceptable_parameters:
            with self.subTest(parameters=params):
                with patch("masu.celery.tasks.fix_parquet_data_types.apply_async") as patch_celery:
                    resp = self.client.get(reverse(self.ENDPOINT), params)
                    self.assertEqual(resp.status_code, 200)
                    # an unparsable simulate value falls back to False
                    simulate = params.get("simulate", False)
                    if simulate == "bad_value":
                        simulate = False
                    expected_kwargs = {
                        "schema_name": self.schema_name,
                        "provider_type": Provider.PROVIDER_AWS_LOCAL,
                        "provider_uuid": self.aws_provider.uuid,
                        "simulate": simulate,
                        "bill_date": self.bill_date,
                        "cleaned_column_mapping": cleaned_column_mapping,
                    }
                    patch_celery.assert_called_once_with((), expected_kwargs, queue=GET_REPORT_FILES_QUEUE)
Empty file.
22 changes: 22 additions & 0 deletions koku/masu/api/upgrade_trino/util/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from common.enum import StrEnum

# Update this to trigger the converter to run again
# even if marked as successful
CONVERTER_VERSION = "1"


class ConversionContextKeys(StrEnum):
    """Keys used in the provider's additional-context dict for conversion metadata."""

    metadata = "conversion_metadata"
    version = "version"
    successful = "successful"
    failed_files = "dtype_failed_files"


class ConversionStates(StrEnum):
    """Per-file states recorded while converting parquet data types."""

    found_s3_file = "found_s3_file"
    downloaded_locally = "downloaded_locally"
    no_changes_needed = "no_changes_needed"
    coerce_required = "coerce_required"
    s3_complete = "sent_to_s3_complete"
    s3_failed = "sent_to_s3_failed"
    conversion_failed = "failed_data_type_conversion"
140 changes: 140 additions & 0 deletions koku/masu/api/upgrade_trino/util/state_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import logging
from datetime import date

from api.common import log_json
from api.provider.provider_manager import ProviderManager
from api.provider.provider_manager import ProviderManagerError
from masu.api.upgrade_trino.util.constants import ConversionContextKeys
from masu.api.upgrade_trino.util.constants import ConversionStates
from masu.api.upgrade_trino.util.constants import CONVERTER_VERSION


LOG = logging.getLogger(__name__)


class StateTracker:
    """Tracks the state of each s3 file for the provider per bill date.

    States are recorded per s3 object key in ``self.tracker``; files that
    were downloaded locally also get an entry in ``self.local_files`` so
    they can be removed once the conversion finishes.
    """

    def __init__(self, provider_uuid: str, bill_date: date):
        self.files = []
        self.tracker = {}  # {s3_object_key: ConversionStates}
        # assumes values are pathlib.Path objects — unlink(missing_ok=True) below
        self.local_files = {}  # {s3_object_key: local file path}
        self.provider_uuid = provider_uuid
        self.bill_date_str = bill_date.strftime("%Y-%m-%d")

    def add_to_queue(self, conversion_metadata):
        """
        Checks the provider object's metadata to see if we should start the task.
        Args:
            conversion_metadata (dict): Metadata for the conversion, keyed by bill date.
        Returns:
            bool: True if the task should be added to the queue, False otherwise.
        """
        bill_metadata = conversion_metadata.get(self.bill_date_str, {})
        if bill_metadata.get(ConversionContextKeys.version) != CONVERTER_VERSION:
            # always kick off a task if the version does not match or exist.
            return True
        if bill_metadata.get(ConversionContextKeys.successful):
            # if the conversion was successful for this version do not kick
            # off a task.
            LOG.info(
                log_json(
                    self.provider_uuid,
                    msg="Conversion already marked as successful",
                    bill_date=self.bill_date_str,
                    provider_uuid=self.provider_uuid,
                )
            )
            return False
        return True

    def set_state(self, s3_obj_key, state):
        """Record the conversion state for a single s3 object."""
        self.tracker[s3_obj_key] = state

    def add_local_file(self, s3_obj_key, local_path):
        """Register a locally downloaded copy of an s3 object and mark its state."""
        self.local_files[s3_obj_key] = local_path
        self.tracker[s3_obj_key] = ConversionStates.downloaded_locally

    def get_files_that_need_updated(self):
        """Returns a mapping of files in the s3 needs
        updating state.
        {s3_object_key: local_file_path} for
        """
        return {
            s3_obj_key: self.local_files.get(s3_obj_key)
            for s3_obj_key, state in self.tracker.items()
            if state == ConversionStates.coerce_required
        }

    def generate_simulate_messages(self):
        """
        Generates the simulate messages.

        Logs one summary line per category, cleans up local files, and
        returns {summary message: [s3 object keys]}.
        """
        files_count = 0
        files_failed = []
        files_need_updated = []
        files_correct = []
        for s3_obj_key, state in self.tracker.items():
            files_count += 1
            if state == ConversionStates.coerce_required:
                files_need_updated.append(s3_obj_key)
            elif state == ConversionStates.no_changes_needed:
                files_correct.append(s3_obj_key)
            else:
                # any state other than "resolved" counts as a failure here.
                files_failed.append(s3_obj_key)
        simulate_info = {
            "Files that have all correct data_types.": files_correct,
            "Files that need to be updated.": files_need_updated,
            "Files that failed to convert.": files_failed,
        }
        for substring, files_list in simulate_info.items():
            LOG.info(
                log_json(
                    self.provider_uuid,
                    msg=substring,
                    file_count=len(files_list),
                    total_count=files_count,
                    bill_date=self.bill_date_str,
                )
            )
        self._clean_local_files()
        return simulate_info

    def _clean_local_files(self):
        """Delete every locally downloaded file (best effort)."""
        for file_path in self.local_files.values():
            file_path.unlink(missing_ok=True)

    def _create_bill_date_metadata(self):
        """Build the conversion-metadata dict stored for this bill date."""
        bill_date_data = {ConversionContextKeys.version: CONVERTER_VERSION}
        # Any file not confirmed uploaded or confirmed unchanged is incomplete.
        incomplete_files = [
            {"key": file_prefix, "state": state}
            for file_prefix, state in self.tracker.items()
            if state not in (ConversionStates.s3_complete, ConversionStates.no_changes_needed)
        ]
        if incomplete_files:
            bill_date_data[ConversionContextKeys.successful] = False
            bill_date_data[ConversionContextKeys.failed_files] = incomplete_files
        else:
            bill_date_data[ConversionContextKeys.successful] = True
        return bill_date_data

    def _check_if_complete(self):
        """Persist the conversion outcome into the provider's additional context."""
        try:
            manager = ProviderManager(self.provider_uuid)
            context = manager.get_additional_context()
            conversion_metadata = context.get(ConversionContextKeys.metadata, {})
            conversion_metadata[self.bill_date_str] = self._create_bill_date_metadata()
            context[ConversionContextKeys.metadata] = conversion_metadata
            manager.model.set_additional_context(context)
            # Fix: provider_uuid was previously passed as the log *message* with
            # the log_json payload as a lazy %-arg, silently dropping the payload.
            # Log the structured payload as the message, like the calls above.
            LOG.info(log_json(self.provider_uuid, msg="setting dtype states", context=context))
        except ProviderManagerError:
            # Best effort: the conversion already happened; record that we
            # could not persist its state instead of failing silently.
            LOG.warning(
                log_json(
                    self.provider_uuid,
                    msg="unable to persist conversion state",
                    bill_date=self.bill_date_str,
                )
            )

    def finalize_and_clean_up(self):
        """Record the final conversion state, then remove local scratch files."""
        self._check_if_complete()
        self._clean_local_files()
Loading

0 comments on commit ecfbf91

Please sign in to comment.