From ba6fbbf4ab933d9d30fd7c8dd5689801d1bc283a Mon Sep 17 00:00:00 2001 From: Muhammad Afaq Shuaib Date: Tue, 31 Dec 2024 14:59:33 +0500 Subject: [PATCH] refactor: csv loader code implementation --- .../data_loaders/csv_loader.py | 387 ++++++++++++------ 1 file changed, 259 insertions(+), 128 deletions(-) diff --git a/course_discovery/apps/course_metadata/data_loaders/csv_loader.py b/course_discovery/apps/course_metadata/data_loaders/csv_loader.py index 89feee5078..f240b9646d 100644 --- a/course_discovery/apps/course_metadata/data_loaders/csv_loader.py +++ b/course_discovery/apps/course_metadata/data_loaders/csv_loader.py @@ -4,6 +4,7 @@ """ import csv import logging +from functools import cache import unicodecsv from django.conf import settings @@ -75,8 +76,17 @@ def __init__( * product_source: slug of the external source that actually owns the product. """ super().__init__(partner, api_url, max_workers, is_threadsafe) - self.error_logs = {} - self.ingestion_summary = { + self.error_logs = {key: [] for key in CSV_LOADER_ERROR_LOG_SEQUENCE} + self.ingestion_summary = self._initialize_ingestion_summary() + self.course_uuids = {} # to show the discovery course ids for each processed course + self.product_type = product_type + self.product_source = self._get_product_source(product_source) + self.reader = self._initialize_csv_reader(csv_path, csv_file, use_gspread_client) + self.ingestion_summary['total_products_count'] = len(self.reader) + + def _initialize_ingestion_summary(self): + """Initialize the ingestion summary dictionary.""" + return { 'total_products_count': 0, 'success_count': 0, 'failure_count': 0, @@ -84,36 +94,35 @@ def __init__( 'created_products': [], 'archived_products': [] } - self.course_uuids = {} # to show the discovery course ids for each processed course - self.product_type = product_type + + def _get_product_source(self, product_source): + """ + Retrieve the product source or raise an exception if product source doesn't exist already + """ try: - self.product_source = Source.objects.get(slug=product_source) + return Source.objects.get(slug=product_source) except Source.DoesNotExist: - logger.exception(f"Unable to locate source with slug {product_source}") + logger.exception(f"Unable to locate source with slug '{product_source}'") raise - for error_log_key in CSV_LOADER_ERROR_LOG_SEQUENCE: - self.error_logs.setdefault(error_log_key, []) - + def _initialize_csv_reader(self, csv_path, csv_file, use_gspread_client): + """ + Initialize the CSV reader based on the input source (csv_path, csv_file or gspread_client) + """ try: if use_gspread_client: - # TODO: add unit tests - product_type_config = settings.PRODUCT_METADATA_MAPPING[product_type][self.product_source.slug] + product_type_config = settings.PRODUCT_METADATA_MAPPING[self.product_type][self.product_source.slug] gspread_client = GspreadClient() - self.reader = gspread_client.read_data(product_type_config) + return list(gspread_client.read_data(product_type_config)) else: - # Read file from the path if given. Otherwise, read from the file - # received from CSVDataLoaderConfiguration. - self.reader = csv.DictReader(open(csv_path, 'r')) if csv_path \ - else list(unicodecsv.DictReader(csv_file)) # lint-amnesty, pylint: disable=consider-using-with + # read the file from the provided path; otherwise, use the file received from CSVDataLoaderConfiguration + return list(csv.DictReader(open(csv_path, 'r'))) if csv_path else list(unicodecsv.DictReader(csv_file)) except FileNotFoundError: - logger.exception("Error opening csv file at path {}".format(csv_path)) # lint-amnesty, pylint: disable=logging-format-interpolation - raise # re-raising exception to avoid moving the code flow - except Exception: - logger.exception("Error reading the input data source") - raise # re-raising exception to avoid moving the code flow - self.reader = list(self.reader) - self.ingestion_summary['total_products_count'] = len(self.reader) + logger.exception(f"Error opening CSV file at path: {csv_path}") + raise + except Exception as e: + logger.exception(f"Error reading input data source: {e}") + raise def ingest(self): # pylint: disable=too-many-statements logger.info("Initiating CSV data loader flow.") @@ -128,41 +137,9 @@ def ingest(self): # pylint: disable=too-many-statements if 'external_identifier' in row: course_external_identifiers.add(row['external_identifier']) - logger.info('Starting data import flow for {}'.format(course_title)) # lint-amnesty, pylint: disable=logging-format-interpolation - if not Organization.objects.filter(key=org_key).exists(): - error_message = CSVIngestionErrorMessages.MISSING_ORGANIZATION.format( - org_key=org_key, - course_title=course_title, - ) - logger.error(error_message) - self._register_ingestion_error(CSVIngestionErrors.MISSING_ORGANIZATION, error_message) - continue - - try: - course_type = CourseType.objects.get(name=row['course_enrollment_track']) - course_run_type = CourseRunType.objects.get(name=row['course_run_enrollment_track']) - except CourseType.DoesNotExist: - error_message = CSVIngestionErrorMessages.MISSING_COURSE_TYPE.format( - course_title=course_title, course_type=row['course_enrollment_track'] - ) - logger.exception(error_message) - self._register_ingestion_error(CSVIngestionErrors.MISSING_COURSE_TYPE, error_message) - continue - except CourseRunType.DoesNotExist: - error_message = CSVIngestionErrorMessages.MISSING_COURSE_RUN_TYPE.format( - course_title=course_title, course_run_type=row['course_run_enrollment_track'] - ) - logger.exception(error_message) - self._register_ingestion_error(CSVIngestionErrors.MISSING_COURSE_RUN_TYPE, error_message) - continue - - missing_fields = self.validate_course_data(course_type, row) - if missing_fields: - error_message = CSVIngestionErrorMessages.MISSING_REQUIRED_DATA.format( - course_title=course_title, missing_data=missing_fields - ) - logger.error(error_message) - self._register_ingestion_error(CSVIngestionErrors.MISSING_REQUIRED_DATA, error_message) + logger.info(f'Starting data import flow for {course_title}') + is_valid, course_type, course_run_type = self._validate_and_process_row(row, course_title, org_key) + if not is_valid: continue course_key = self.get_course_key(org_key, row['number']) @@ -170,16 +147,12 @@ def ingest(self): # pylint: disable=too-many-statements is_course_already_ingested = bool(course) and str(course.uuid) in self.course_uuids is_course_created = False is_course_run_created = False - course_run_restriction = ( - None - if row.get('restriction_type', None) == 'None' - else row.get('restriction_type', None) - ) + course_run_restriction = self._get_course_run_restriction(row) is_future_variant = row.get('is_future_variant') == 'True' if course: try: - logger.info("Course {} is located in the database.".format(course_key)) # lint-amnesty, pylint: disable=logging-format-interpolation + logger.info(f"Course {course_key} is located in the database.") course_run, is_course_run_created = self._get_or_create_course_run( row, course, course_type, course_run_type.uuid ) @@ -187,19 +160,20 @@ def ingest(self): # pylint: disable=too-many-statements logger.exception(exc) continue else: - logger.info("Course key {} could not be found in database, creating the course.".format(course_key)) # lint-amnesty, pylint: disable=logging-format-interpolation + logger.info(f"Course key {course_key} could not be found in database, creating the course.") try: _ = self._create_course(row, course_type, course_run_type.uuid) except Exception as exc: # pylint: disable=broad-except exception_message = exc if hasattr(exc, 'response'): exception_message = exc.response.content.decode('utf-8') - error_message = CSVIngestionErrorMessages.COURSE_CREATE_ERROR.format( - course_title=course_title, - exception_message=exception_message + self._log_ingestion_error( + CSVIngestionErrors.COURSE_CREATE_ERROR, + CSVIngestionErrorMessages.COURSE_CREATE_ERROR.format( + course_title=course_title, + exception_message=exception_message + ) ) - logger.exception(error_message) - self._register_ingestion_error(CSVIngestionErrors.COURSE_CREATE_ERROR, error_message) continue course = Course.everything.select_related('type').get(key=course_key, partner=self.partner) @@ -216,9 +190,10 @@ def ingest(self): # pylint: disable=too-many-statements row['image'], headers=self.REQUEST_USER_AGENT_HEADERS) if not is_downloaded: - error_message = CSVIngestionErrorMessages.IMAGE_DOWNLOAD_FAILURE.format(course_title=course_title) - logger.error(error_message) - self._register_ingestion_error(CSVIngestionErrors.IMAGE_DOWNLOAD_FAILURE, error_message) + self._log_ingestion_error( + CSVIngestionErrors.IMAGE_DOWNLOAD_FAILURE, + CSVIngestionErrorMessages.IMAGE_DOWNLOAD_FAILURE.format(course_title=course_title) + ) continue if not is_course_created: self.add_product_source(course) @@ -229,12 +204,12 @@ def ingest(self): # pylint: disable=too-many-statements exception_message = exc if hasattr(exc, 'response'): exception_message = exc.response.content.decode('utf-8') - error_message = CSVIngestionErrorMessages.COURSE_UPDATE_ERROR.format( - course_title=course_title, - exception_message=exception_message + self._log_ingestion_error( + CSVIngestionErrors.COURSE_UPDATE_ERROR, + CSVIngestionErrorMessages.COURSE_UPDATE_ERROR.format( + course_title=course_title, exception_message=exception_message + ) ) - logger.exception(error_message) - self._register_ingestion_error(CSVIngestionErrors.COURSE_UPDATE_ERROR, error_message) continue if row.get('organization_logo_override'): @@ -246,11 +221,12 @@ def ingest(self): # pylint: disable=too-many-statements headers=self.REQUEST_USER_AGENT_HEADERS ) if not is_logo_downloaded: - error_message = CSVIngestionErrorMessages.LOGO_IMAGE_DOWNLOAD_FAILURE.format( - course_title=course_title + self._log_ingestion_error( + CSVIngestionErrors.LOGO_IMAGE_DOWNLOAD_FAILURE, + CSVIngestionErrorMessages.LOGO_IMAGE_DOWNLOAD_FAILURE.format( + course_title=course_title + ) ) - logger.error(error_message) - self._register_ingestion_error(CSVIngestionErrors.LOGO_IMAGE_DOWNLOAD_FAILURE, error_message) else: try: @@ -260,12 +236,12 @@ def ingest(self): # pylint: disable=too-many-statements except Exception as exc: # pylint: disable=broad-except exception_message = exc if hasattr(exc, 'response'): - error_message = CSVIngestionErrorMessages.COURSE_ENTITLEMENT_PRICE_UPDATE_ERROR.format( - course_title=course_title, - exception_message=exception_message + self._log_ingestion_error( + CSVIngestionErrors.COURSE_UPDATE_ERROR, + CSVIngestionErrorMessages.COURSE_ENTITLEMENT_PRICE_UPDATE_ERROR.format( + course_title=course_title, exception_message=exception_message + ) ) - logger.exception(error_message) - self._register_ingestion_error(CSVIngestionErrors.COURSE_UPDATE_ERROR, error_message) continue # No need to update the course run if the run is already in the review @@ -276,12 +252,12 @@ def ingest(self): # pylint: disable=too-many-statements exception_message = exc if hasattr(exc, 'response'): exception_message = exc.response.content.decode('utf-8') - error_message = CSVIngestionErrorMessages.COURSE_RUN_UPDATE_ERROR.format( - course_title=course_title, - exception_message=exception_message + self._log_ingestion_error( + CSVIngestionErrors.COURSE_RUN_UPDATE_ERROR, + CSVIngestionErrorMessages.COURSE_RUN_UPDATE_ERROR.format( + course_title=course_title, exception_message=exception_message + ) ) - logger.exception(error_message) - self._register_ingestion_error(CSVIngestionErrors.COURSE_RUN_UPDATE_ERROR, error_message) continue course_run.refresh_from_db() @@ -295,16 +271,9 @@ def ingest(self): # pylint: disable=too-many-statements course_run.save(update_fields=['status'], send_emails=False) self._complete_run_review(row, course_run) - logger.info("Course and course run updated successfully for course key {}".format(course_key)) # lint-amnesty, pylint: disable=logging-format-interpolation + logger.info(f"Course and course run updated successfully for course key {course_key}") - self.course_uuids[str(course.uuid)] = { - "title": course_title, - "price": ( - row.get("verified_price") if row.get("restriction_type", "None") != - CourseRunRestrictionType.CustomB2BEnterprise.value else - self.course_uuids.get(str(course.uuid), {}).get("price", None) - ), - } + self.course_uuids[str(course.uuid)] = {"title": course_title, "price": self._get_course_price(row, course)} self._register_successful_ingestion( str(course.uuid), str(course_run.variant_id), is_course_created, is_course_run_created, @@ -317,6 +286,179 @@ def ingest(self): # pylint: disable=too-many-statements self._render_error_logs() self._render_course_uuids() + self.clear_caches() + + def _get_course_price(self, row, course): + """ + Determine the price of the course based on the row data. + + Args: + row (dict): The data row containing course details. + course: The course instance. + + Returns: + float | None: The course price or None if unavailable. + """ + if row.get("restriction_type", "None") != CourseRunRestrictionType.CustomB2BEnterprise.value: + return row.get("verified_price") + return self.course_uuids.get(str(course.uuid), {}).get("price", None) + + def _get_course_run_restriction(self, row): + return None if row.get('restriction_type', None) == 'None' else row.get('restriction_type', None) + + @staticmethod + @cache + def _validate_organization(org_key): + """ + Helper method to validate the organization key + + Args: + org_key (str): Organization key + + Returns: + bool: True if the organization exists, False otherwise + """ + return Organization.objects.filter(key=org_key).exists() + + def validate_organization(self, org_key, course_title): + """ + Wrapper method to validate the organization key and log an error if the organization does not exist. + + Args: + org_key (str): Organization key + course_title (str): Course title + + Returns: + bool: True if the organization exists, False otherwise + """ + if not self._validate_organization(org_key): + self._log_ingestion_error( + CSVIngestionErrors.MISSING_ORGANIZATION, + CSVIngestionErrorMessages.MISSING_ORGANIZATION.format( + course_title=course_title, org_key=org_key + ) + ) + return False + return True + + @staticmethod + @cache + def get_course_type(course_type_name): + """ + Retrieve a CourseType object, using a cache to avoid redundant queries. + + Args: + course_type_name (str): Course type name + + Returns: + CourseType: CourseType object + """ + try: + return CourseType.objects.get(name=course_type_name) + except CourseType.DoesNotExist: + return None + + @staticmethod + @cache + def get_course_run_type(course_run_type_name): + """ + Retrieve a CourseRunType object, using a cache to avoid redundant queries. + + Args: + course_run_type_name (str): Course run type name + """ + try: + return CourseRunType.objects.get(name=course_run_type_name) + except CourseRunType.DoesNotExist: + return None + + def _validate_and_process_row(self, row, course_title, org_key): + """ + Validate the row data and process the row if it is valid. + + Args: + row (dict): course data row + course_title (str): Course title + org_key (str): Organization key + + Returns: + bool: True if the row is valid, False otherwise + CourseType: CourseType object + CourseRunType: CourseRunType object + """ + if not self.validate_organization(org_key, course_title): + return False, None, None + + def validate_course_and_course_run_types(row, course_title): + """ + Helper method to validate course and course run types. + + Args: + row (dict): Course data row + course_title (str): Course title + + Returns: + bool: True if course and course run types are valid, False otherwise + CourseType: CourseType object + CourseRunType: CourseRunType object + """ + course_type = self.get_course_type(row["course_enrollment_track"]) + if not course_type: + self._log_ingestion_error( + CSVIngestionErrors.MISSING_COURSE_TYPE, + CSVIngestionErrorMessages.MISSING_COURSE_TYPE.format( + course_title=course_title, course_type=row["course_enrollment_track"] + ), + ) + return False, None, None + + course_run_type = self.get_course_run_type(row["course_run_enrollment_track"]) + if not course_run_type: + self._log_ingestion_error( + CSVIngestionErrors.MISSING_COURSE_RUN_TYPE, + CSVIngestionErrorMessages.MISSING_COURSE_RUN_TYPE.format( + course_title=course_title, course_run_type=row["course_run_enrollment_track"] + ), + ) + return False, None, None + + return True, course_type, course_run_type + + is_valid, course_type, course_run_type = validate_course_and_course_run_types(row, course_title) + if not is_valid: + return False, course_type, course_run_type + + missing_fields = self.validate_course_data(course_type, row) + if missing_fields: + self._log_ingestion_error( + CSVIngestionErrors.MISSING_REQUIRED_DATA, + CSVIngestionErrorMessages.MISSING_REQUIRED_DATA.format( + course_title=course_title, missing_data=missing_fields + ) + ) + return False, course_type, course_run_type + + return True, course_type, course_run_type + + def _log_ingestion_error(self, error_code, message): + """ + Log the error message and continue the ingestion process. + + Args: + error_code: Error code + message (str): Error message + """ + logger.error(message) + self._register_ingestion_error(error_code, message) + + @classmethod + def clear_caches(cls): + """ + Clears all LRU caches associated with the class. + """ + cls.get_course_type.cache_clear() + cls.get_course_run_type.cache_clear() + cls._validate_organization.cache_clear() def _get_or_create_course_run(self, data, course, course_type, course_run_type_uuid): """ @@ -416,8 +558,7 @@ def _render_course_uuids(self): if self.course_uuids: logger.info("Course UUIDs:") for course_uuid, course_dict in self.course_uuids.items(): - logger.info( - "{}:{}".format(course_uuid, course_dict['title'])) # lint-amnesty, pylint: disable=logging-format-interpolation + logger.info(f"{course_uuid}:{course_dict['title']}") def _register_ingestion_error(self, error_key, error_message): """ @@ -669,7 +810,7 @@ def verify_and_get_language_tags(self, language_str): and return a list of language codes. """ languages_codes_list = [] - languages_list = language_str.split(',') + languages_list = language_str.split(",") for language in languages_list: language = language.strip() language_obj = LanguageTag.objects.filter( @@ -677,9 +818,7 @@ def verify_and_get_language_tags(self, language_str): ).first() if not language_obj: raise Exception( # pylint: disable=broad-exception-raised - 'Language {} from provided string {} is either missing or an invalid ietf language'.format( - language, language_str - ) + f"Language {language} from provided string {language_str} is either missing or an invalid ietf language" # pylint: disable=line-too-long ) languages_codes_list.append(language_obj.code) return languages_codes_list @@ -709,7 +848,7 @@ def _create_course(self, data, course_type, course_run_type_uuid): request_data = self._create_course_api_request_data(data, course_type, course_run_type_uuid) response = self._call_course_api('POST', url, request_data) if response.status_code not in (200, 201): - logger.info("Course creation response: {}".format(response.content)) # lint-amnesty, pylint: disable=logging-format-interpolation + logger.info(f"Course creation response: {response.content}") return response.json() def _update_course_entitlement_price(self, data, course_uuid, course_type, is_draft=False): @@ -751,7 +890,7 @@ def _create_course_run(self, data, course, course_type, course_run_type_uuid, re request_data = self._create_course_run_api_request_data(data, course, course_type, course_run_type_uuid, rerun) response = self._call_course_api('POST', url, request_data) if response.status_code not in (200, 201): - logger.info("Course run creation response: {}".format(response.content)) # lint-amnesty, pylint: disable=logging-format-interpolation + logger.info(f"Course run creation response: {response.content}") return response.json() def _update_course(self, data, course, is_draft): @@ -764,7 +903,7 @@ def _update_course(self, data, course, is_draft): response = self._call_course_api('PATCH', url, request_data) if response.status_code not in (200, 201): - logger.info("Course update response: {}".format(response.content)) # lint-amnesty, pylint: disable=logging-format-interpolation + logger.info(f"Course update response: {response.content}") return response.json() def _update_course_run(self, data, course_run, course_type, is_draft): @@ -776,7 +915,7 @@ def _update_course_run(self, data, course_run, course_type, is_draft): request_data = self._update_course_run_request_data(data, course_run, course_type, is_draft) response = self._call_course_api('PATCH', url, request_data) if response.status_code not in (200, 201): - logger.info("Course run update response: {}".format(response.content)) # lint-amnesty, pylint: disable=logging-format-interpolation + logger.info(f"Course run update response: {response.content}") return response.json() def _complete_run_review(self, data, course_run): @@ -833,9 +972,7 @@ def get_subject_slugs(self, *subjects): sub_obj = Subject.objects.get(translations__name=subject, translations__language_code='en') subject_slugs.append(sub_obj.slug) except Subject.DoesNotExist: - logger.exception("Unable to locate subject {} in the database. Skipping subject association".format( # lint-amnesty, pylint: disable=logging-format-interpolation - subject - )) + logger.exception(f"Unable to locate subject {subject} in the database. Skipping subject association") raise return subject_slugs @@ -857,7 +994,7 @@ def process_collaborators(self, collaborators, course_key): collaborator_obj, created = Collaborator.objects.get_or_create(name=collaborator) collaborator_uuids.append(str(collaborator_obj.uuid)) if created: - logger.info("Collaborator {} created for course {}".format(collaborator, course_key)) # lint-amnesty, pylint: disable=logging-format-interpolation + logger.info(f"Collaborator {collaborator} created for course {course_key}") return collaborator_uuids def process_staff_names(self, staff_names, course_run_key): @@ -884,10 +1021,7 @@ def process_staff_names(self, staff_names, course_run_key): ) staff_uuids.append(str(person.uuid)) if created: - logger.info("Staff with name {} has been created for course run {}".format( # lint-amnesty, pylint: disable=logging-format-interpolation - staff_name, - course_run_key - )) + logger.info(f"Staff with name {staff_name} has been created for course run {course_run_key}") return staff_uuids def process_heading_blurb(self, heading, blurb): @@ -906,15 +1040,12 @@ def process_stats(self, stat1, stat1_text, stat2, stat2_text): """ Return a list of stat/fact dicts if valid input values are provided. """ - stats = [] - stat1_dict = self.process_heading_blurb(stat1, stat1_text) - stat2_dict = self.process_heading_blurb(stat2, stat2_text) - - if stat1_dict: - stats.append(stat1_dict) - if stat2_dict: - stats.append(stat2_dict) - return stats + return [ + stat for stat in [ + self.process_heading_blurb(stat1, stat1_text), + self.process_heading_blurb(stat2, stat2_text), + ] if stat + ] def process_meta_information(self, meta_title, meta_description, meta_keywords): """