diff --git a/custom_commands.py b/custom_commands.py index 2b2f510c..a4cc6378 100644 --- a/custom_commands.py +++ b/custom_commands.py @@ -189,11 +189,11 @@ def value_repr(name, value): return str(value) -class CSVFormatter(CopyFormatter): - file_format = 'csv' +class CompressionMixin: def compression(self, comp_type): - """String (constant) that specifies to compresses the unloaded data files using the specified compression algorithm.""" + """String (constant) that specifies to compresses the unloaded data files using the specified compression + algorithm.""" if isinstance(comp_type, string_types): comp_type = comp_type.lower() _available_options = ['auto', 'gzip', 'bz2', 'brotli', 'zstd', 'deflate', 'raw_deflate', None] @@ -202,40 +202,8 @@ def compression(self, comp_type): self.options['COMPRESSION'] = comp_type return self - def _check_delimiter(self, delimiter, delimiter_txt): - """ - Check if a delimiter is either a string of length 1 or an integer. In case of - a string delimiter, take into account that the actual string may be longer, - but still evaluate to a single character (like "\\n" or r"\n" - """ - if isinstance(delimiter, NoneType): - return - if isinstance(delimiter, string_types): - delimiter_processed = delimiter.encode().decode("unicode_escape") - if len(delimiter_processed) == 1: - return - if isinstance(delimiter, int): - return - raise TypeError( - "{} should be a single character, that is either a string, or a number".format(delimiter_txt)) - def record_delimiter(self, deli_type): - """Character that separates records in an unloaded file.""" - self._check_delimiter(deli_type, "Record delimiter") - if isinstance(deli_type, int): - self.options['RECORD_DELIMITER'] = hex(deli_type) - else: - self.options['RECORD_DELIMITER'] = deli_type - return self - - def field_delimiter(self, deli_type): - """Character that separates fields in an unloaded file.""" - self._check_delimiter(deli_type, "Field delimiter") - if isinstance(deli_type, int): - self.options['FIELD_DELIMITER'] = hex(deli_type) - else: - self.options['FIELD_DELIMITER'] = deli_type - return self +class FileExtensionMixin: def file_extension(self, ext): """String that specifies the extension for files unloaded to a stage. Accepts any extension. The user is @@ -245,6 +213,9 @@ def file_extension(self, ext): self.options['FILE_EXTENSION'] = ext return self + +class FormatsDefinitionMixin: + def date_format(self, dt_frmt): """String that defines the format of date values in the unloaded data files.""" if not isinstance(dt_frmt, string_types): @@ -277,6 +248,99 @@ def binary_format(self, bin_fmt): self.options['BINARY_FORMAT'] = bin_fmt return self + +class NullIfMixin: + + def null_if(self, null_value): + """Copying into a table these strings will be replaced by a NULL, while copying out of Snowflake will replace + NULL values with the first string""" + if not isinstance(null_value, Sequence): + raise TypeError('Parameter null_value should be an iterable') + self.options['NULL_IF'] = tuple(null_value) + return self + + +class ReplaceInvalidCharactersMixin: + + def replace_invalid_characters(self, replace_invalid_characters: bool) -> 'JSONFormatter': + """ + Specifies whether to replace invalid UTF-8 characters with the Unicode replacement character. + + :param replace_invalid_characters: True to replace invalid UTF-8 characters + :return: the JSONFormatter + """ + if not isinstance(replace_invalid_characters, bool): + raise TypeError("replace_invalid_characters should be a bool") + self.options['REPLACE_INVALID_CHARACTERS'] = replace_invalid_characters + return self + + +class SkipByteOrderMarkMixin: + + def skip_byte_order_mark(self, skip_byte_order_mark: bool) -> 'JSONFormatter': + """ + Specifies whether to skip the BOM (byte order mark), if present in a data file. + + :param skip_byte_order_mark: True to skip the BOM if present in data file. + :return: the JSONFormatter + """ + if not isinstance(skip_byte_order_mark, bool): + raise TypeError("skip_byte_order_mark should be a bool") + self.options["SKIP_BYTE_ORDER_MARK"] = skip_byte_order_mark + return self + + +class TrimSpaceMixin: + + def trim_space(self, trim_space): + """ + Remove leading or trailing white spaces + """ + if not isinstance(trim_space, bool): + raise TypeError("trim_space should be a bool") + self.options['TRIM_SPACE'] = trim_space + return self + + +class CSVFormatter(CopyFormatter, CompressionMixin, FileExtensionMixin, FormatsDefinitionMixin, NullIfMixin, + ReplaceInvalidCharactersMixin, SkipByteOrderMarkMixin, TrimSpaceMixin): + file_format = 'csv' + + def _check_delimiter(self, delimiter, delimiter_txt): + """ + Check if a delimiter is either a string of length 1 or an integer. In case of + a string delimiter, take into account that the actual string may be longer, + but still evaluate to a single character (like "\\n" or r"\n" + """ + if isinstance(delimiter, NoneType): + return + if isinstance(delimiter, string_types): + delimiter_processed = delimiter.encode().decode("unicode_escape") + if len(delimiter_processed) == 1: + return + if isinstance(delimiter, int): + return + raise TypeError( + "{} should be a single character, that is either a string, or a number".format(delimiter_txt)) + + def record_delimiter(self, deli_type): + """Character that separates records in an unloaded file.""" + self._check_delimiter(deli_type, "Record delimiter") + if isinstance(deli_type, int): + self.options['RECORD_DELIMITER'] = hex(deli_type) + else: + self.options['RECORD_DELIMITER'] = deli_type + return self + + def field_delimiter(self, deli_type): + """Character that separates fields in an unloaded file.""" + self._check_delimiter(deli_type, "Field delimiter") + if isinstance(deli_type, int): + self.options['FIELD_DELIMITER'] = hex(deli_type) + else: + self.options['FIELD_DELIMITER'] = deli_type + return self + def escape(self, esc): """Character used as the escape character for any field values.""" self._check_delimiter(esc, "Escape") @@ -303,14 +367,6 @@ def field_optionally_enclosed_by(self, enc): self.options['FIELD_OPTIONALLY_ENCLOSED_BY'] = enc return self - def null_if(self, null_value): - """Copying into a table these strings will be replaced by a NULL, while copying out of Snowflake will replace - NULL values with the first string""" - if not isinstance(null_value, Sequence): - raise TypeError('Parameter null_value should be an iterable') - self.options['NULL_IF'] = tuple(null_value) - return self - def skip_header(self, skip_header): """ Number of header rows to be skipped at the beginning of the file @@ -320,15 +376,6 @@ def skip_header(self, skip_header): self.options['SKIP_HEADER'] = skip_header return self - def trim_space(self, trim_space): - """ - Remove leading or trailing white spaces - """ - if not isinstance(trim_space, bool): - raise TypeError("trim_space should be a bool") - self.options['TRIM_SPACE'] = trim_space - return self - def error_on_column_count_mismatch(self, error_on_col_count_mismatch): """ Generate a parsing error if the number of delimited columns (i.e. fields) in @@ -340,31 +387,79 @@ def error_on_column_count_mismatch(self, error_on_col_count_mismatch): return self -class JSONFormatter(CopyFormatter): - """Format specific functions""" +class JSONFormatter(CopyFormatter, CompressionMixin, FileExtensionMixin, FormatsDefinitionMixin, + NullIfMixin, ReplaceInvalidCharactersMixin, SkipByteOrderMarkMixin, TrimSpaceMixin): + """JSON format specific functions""" file_format = 'json' - def compression(self, comp_type): - """String (constant) that specifies to compresses the unloaded data files using the specified compression algorithm.""" - if isinstance(comp_type, string_types): - comp_type = comp_type.lower() - _available_options = ['auto', 'gzip', 'bz2', 'brotli', 'zstd', 'deflate', 'raw_deflate', None] - if comp_type not in _available_options: - raise TypeError("Compression type should be one of : {}".format(_available_options)) - self.options['COMPRESSION'] = comp_type + def null_if(self, null_value: Sequence) -> 'JSONFormatter': + """Copying into a table these strings will be replaced by a NULL, while copying out of Snowflake will replace + NULL values with the first string""" + if not isinstance(null_value, Sequence): + raise TypeError('Parameter null_value should be an iterable') + self.options['NULL_IF'] = tuple(null_value) return self - def file_extension(self, ext): - """String that specifies the extension for files unloaded to a stage. Accepts any extension. The user is - responsible for specifying a valid file extension that can be read by the desired software or service. """ - if not isinstance(ext, (NoneType, string_types)): - raise TypeError("File extension should be a string") - self.options['FILE_EXTENSION'] = ext + def enable_octal(self, enable_octal: bool) -> 'JSONFormatter': + """ + Enables parsing of octal numbers. + :param enable_octal: + :return: the JSONFormatter + """ + if not isinstance(enable_octal, bool): + raise TypeError("enable_octal should be a bool") + self.options['ENABLE_OCTAL'] = enable_octal + return self + + def allow_duplicate(self, allow_duplicate: bool) -> 'JSONFormatter': + """ + Allows duplicate object filed names (only the last one is preserved) + :param allow_duplicate: True to allow duplicate fields. + :return: the JSONFormatter. + """ + if not isinstance(allow_duplicate, bool): + raise TypeError("enable_duplicate should be a bool") + self.options['ALLOW_DUPLICATE'] = allow_duplicate + return self + + def strip_outer_array(self, strip_outer_array: bool) -> 'JSONFormatter': + """ + Instructs the JSON parser to remove outer brackets + :param strip_outer_array: True to strip outer brackets. + :return: the JSONFormatter + """ + if not isinstance(strip_outer_array, bool): + raise TypeError("strip_outer_array should be a bool") + self.options['STRIP_OUTER_ARRAY'] = strip_outer_array + return self + + def strip_null_values(self, strip_null_values: bool) -> 'JSONFormatter': + """ + Instructs the JSON parser to remove object fields or array elements containing null values. + + :param strip_null_values: Tru to strip null fields and array elements + :return: the JSONFormatter + """ + if not isinstance(strip_null_values, bool): + raise TypeError("strip_null_values should be a bool") + self.options['STRIP_NULL_VALUES'] = strip_null_values + return self + + def ignore_utf8_errors(self, ignore_utf8_errors: bool) -> 'JSONFormatter': + """ + Specifies whether UTF-8 encoding errors produce error conditions. If set to True, any invalid UTF-8 sequences + are silently replaced with the Unicode replacement character. + :param ignore_utf8_errors: True to ignore utf8 invalid sequences + :return: the JSONFormatter + """ + if not isinstance(ignore_utf8_errors, bool): + raise TypeError("ignore_utf8_errors should be a bool") + self.options["IGNORE_UTF8_ERRORS"] = ignore_utf8_errors return self -class PARQUETFormatter(CopyFormatter): +class PARQUETFormatter(CopyFormatter, NullIfMixin, TrimSpaceMixin): """Format specific functions""" file_format = 'parquet'