Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add missing JSONFormatter options #292

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 165 additions & 70 deletions custom_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,11 @@ def value_repr(name, value):
return str(value)


class CSVFormatter(CopyFormatter):
file_format = 'csv'
class CompressionMixin:

def compression(self, comp_type):
"""String (constant) that specifies to compresses the unloaded data files using the specified compression algorithm."""
"""String (constant) that specifies to compresses the unloaded data files using the specified compression
algorithm."""
if isinstance(comp_type, string_types):
comp_type = comp_type.lower()
_available_options = ['auto', 'gzip', 'bz2', 'brotli', 'zstd', 'deflate', 'raw_deflate', None]
Expand All @@ -202,40 +202,8 @@ def compression(self, comp_type):
self.options['COMPRESSION'] = comp_type
return self

def _check_delimiter(self, delimiter, delimiter_txt):
"""
Check if a delimiter is either a string of length 1 or an integer. In case of
a string delimiter, take into account that the actual string may be longer,
but still evaluate to a single character (like "\\n" or r"\n"
"""
if isinstance(delimiter, NoneType):
return
if isinstance(delimiter, string_types):
delimiter_processed = delimiter.encode().decode("unicode_escape")
if len(delimiter_processed) == 1:
return
if isinstance(delimiter, int):
return
raise TypeError(
"{} should be a single character, that is either a string, or a number".format(delimiter_txt))

def record_delimiter(self, deli_type):
"""Character that separates records in an unloaded file."""
self._check_delimiter(deli_type, "Record delimiter")
if isinstance(deli_type, int):
self.options['RECORD_DELIMITER'] = hex(deli_type)
else:
self.options['RECORD_DELIMITER'] = deli_type
return self

def field_delimiter(self, deli_type):
"""Character that separates fields in an unloaded file."""
self._check_delimiter(deli_type, "Field delimiter")
if isinstance(deli_type, int):
self.options['FIELD_DELIMITER'] = hex(deli_type)
else:
self.options['FIELD_DELIMITER'] = deli_type
return self
class FileExtensionMixin:

def file_extension(self, ext):
"""String that specifies the extension for files unloaded to a stage. Accepts any extension. The user is
Expand All @@ -245,6 +213,9 @@ def file_extension(self, ext):
self.options['FILE_EXTENSION'] = ext
return self


class FormatsDefinitionMixin:

def date_format(self, dt_frmt):
"""String that defines the format of date values in the unloaded data files."""
if not isinstance(dt_frmt, string_types):
Expand Down Expand Up @@ -277,6 +248,99 @@ def binary_format(self, bin_fmt):
self.options['BINARY_FORMAT'] = bin_fmt
return self


class NullIfMixin:

def null_if(self, null_value):
"""Copying into a table these strings will be replaced by a NULL, while copying out of Snowflake will replace
NULL values with the first string"""
if not isinstance(null_value, Sequence):
raise TypeError('Parameter null_value should be an iterable')
self.options['NULL_IF'] = tuple(null_value)
return self


class ReplaceInvalidCharactersMixin:

def replace_invalid_characters(self, replace_invalid_characters: bool) -> 'JSONFormatter':
"""
Specifies whether to replace invalid UTF-8 characters with the Unicode replacement character.

:param replace_invalid_characters: True to replace invalid UTF-8 characters
:return: the JSONFormatter
"""
if not isinstance(replace_invalid_characters, bool):
raise TypeError("replace_invalid_characters should be a bool")
self.options['REPLACE_INVALID_CHARACTERS'] = replace_invalid_characters
return self


class SkipByteOrderMarkMixin:

def skip_byte_order_mark(self, skip_byte_order_mark: bool) -> 'JSONFormatter':
"""
Specifies whether to skip the BOM (byte order mark), if present in a data file.

:param skip_byte_order_mark: True to skip the BOM if present in data file.
:return: the JSONFormatter
"""
if not isinstance(skip_byte_order_mark, bool):
raise TypeError("skip_byte_order_mark should be a bool")
self.options["SKIP_BYTE_ORDER_MARK"] = skip_byte_order_mark
return self


class TrimSpaceMixin:

def trim_space(self, trim_space):
"""
Remove leading or trailing white spaces
"""
if not isinstance(trim_space, bool):
raise TypeError("trim_space should be a bool")
self.options['TRIM_SPACE'] = trim_space
return self


class CSVFormatter(CopyFormatter, CompressionMixin, FileExtensionMixin, FormatsDefinitionMixin, NullIfMixin,
ReplaceInvalidCharactersMixin, SkipByteOrderMarkMixin, TrimSpaceMixin):
file_format = 'csv'

def _check_delimiter(self, delimiter, delimiter_txt):
"""
Check if a delimiter is either a string of length 1 or an integer. In case of
a string delimiter, take into account that the actual string may be longer,
but still evaluate to a single character (like "\\n" or r"\n"
"""
if isinstance(delimiter, NoneType):
return
if isinstance(delimiter, string_types):
delimiter_processed = delimiter.encode().decode("unicode_escape")
if len(delimiter_processed) == 1:
return
if isinstance(delimiter, int):
return
raise TypeError(
"{} should be a single character, that is either a string, or a number".format(delimiter_txt))

def record_delimiter(self, deli_type):
"""Character that separates records in an unloaded file."""
self._check_delimiter(deli_type, "Record delimiter")
if isinstance(deli_type, int):
self.options['RECORD_DELIMITER'] = hex(deli_type)
else:
self.options['RECORD_DELIMITER'] = deli_type
return self

def field_delimiter(self, deli_type):
"""Character that separates fields in an unloaded file."""
self._check_delimiter(deli_type, "Field delimiter")
if isinstance(deli_type, int):
self.options['FIELD_DELIMITER'] = hex(deli_type)
else:
self.options['FIELD_DELIMITER'] = deli_type
return self

def escape(self, esc):
"""Character used as the escape character for any field values."""
self._check_delimiter(esc, "Escape")
Expand All @@ -303,14 +367,6 @@ def field_optionally_enclosed_by(self, enc):
self.options['FIELD_OPTIONALLY_ENCLOSED_BY'] = enc
return self

def null_if(self, null_value):
"""Copying into a table these strings will be replaced by a NULL, while copying out of Snowflake will replace
NULL values with the first string"""
if not isinstance(null_value, Sequence):
raise TypeError('Parameter null_value should be an iterable')
self.options['NULL_IF'] = tuple(null_value)
return self

def skip_header(self, skip_header):
"""
Number of header rows to be skipped at the beginning of the file
Expand All @@ -320,15 +376,6 @@ def skip_header(self, skip_header):
self.options['SKIP_HEADER'] = skip_header
return self

def trim_space(self, trim_space):
"""
Remove leading or trailing white spaces
"""
if not isinstance(trim_space, bool):
raise TypeError("trim_space should be a bool")
self.options['TRIM_SPACE'] = trim_space
return self

def error_on_column_count_mismatch(self, error_on_col_count_mismatch):
"""
Generate a parsing error if the number of delimited columns (i.e. fields) in
Expand All @@ -340,31 +387,79 @@ def error_on_column_count_mismatch(self, error_on_col_count_mismatch):
return self


class JSONFormatter(CopyFormatter):
"""Format specific functions"""
class JSONFormatter(CopyFormatter, CompressionMixin, FileExtensionMixin, FormatsDefinitionMixin,
NullIfMixin, ReplaceInvalidCharactersMixin, SkipByteOrderMarkMixin, TrimSpaceMixin):
"""JSON format specific functions"""

file_format = 'json'

def compression(self, comp_type):
"""String (constant) that specifies to compresses the unloaded data files using the specified compression algorithm."""
if isinstance(comp_type, string_types):
comp_type = comp_type.lower()
_available_options = ['auto', 'gzip', 'bz2', 'brotli', 'zstd', 'deflate', 'raw_deflate', None]
if comp_type not in _available_options:
raise TypeError("Compression type should be one of : {}".format(_available_options))
self.options['COMPRESSION'] = comp_type
def null_if(self, null_value: Sequence) -> 'JSONFormatter':
"""Copying into a table these strings will be replaced by a NULL, while copying out of Snowflake will replace
NULL values with the first string"""
if not isinstance(null_value, Sequence):
raise TypeError('Parameter null_value should be an iterable')
self.options['NULL_IF'] = tuple(null_value)
return self

def file_extension(self, ext):
"""String that specifies the extension for files unloaded to a stage. Accepts any extension. The user is
responsible for specifying a valid file extension that can be read by the desired software or service. """
if not isinstance(ext, (NoneType, string_types)):
raise TypeError("File extension should be a string")
self.options['FILE_EXTENSION'] = ext
def enable_octal(self, enable_octal: bool) -> 'JSONFormatter':
"""
Enables parsing of octal numbers.
:param enable_octal:
:return: the JSONFormatter
"""
if not isinstance(enable_octal, bool):
raise TypeError("enable_octal should be a bool")
self.options['ENABLE_OCTAL'] = enable_octal
return self

def allow_duplicate(self, allow_duplicate: bool) -> 'JSONFormatter':
"""
Allows duplicate object filed names (only the last one is preserved)
:param allow_duplicate: True to allow duplicate fields.
:return: the JSONFormatter.
"""
if not isinstance(allow_duplicate, bool):
raise TypeError("enable_duplicate should be a bool")
self.options['ALLOW_DUPLICATE'] = allow_duplicate
return self

def strip_outer_array(self, strip_outer_array: bool) -> 'JSONFormatter':
"""
Instructs the JSON parser to remove outer brackets
:param strip_outer_array: True to strip outer brackets.
:return: the JSONFormatter
"""
if not isinstance(strip_outer_array, bool):
raise TypeError("strip_outer_array should be a bool")
self.options['STRIP_OUTER_ARRAY'] = strip_outer_array
return self

def strip_null_values(self, strip_null_values: bool) -> 'JSONFormatter':
"""
Instructs the JSON parser to remove object fields or array elements containing null values.

:param strip_null_values: Tru to strip null fields and array elements
:return: the JSONFormatter
"""
if not isinstance(strip_null_values, bool):
raise TypeError("strip_null_values should be a bool")
self.options['STRIP_NULL_VALUES'] = strip_null_values
return self

def ignore_utf8_errors(self, ignore_utf8_errors: bool) -> 'JSONFormatter':
"""
Specifies whether UTF-8 encoding errors produce error conditions. If set to True, any invalid UTF-8 sequences
are silently replaced with the Unicode replacement character.
:param ignore_utf8_errors: True to ignore utf8 invalid sequences
:return: the JSONFormatter
"""
if not isinstance(ignore_utf8_errors, bool):
raise TypeError("ignore_utf8_errors should be a bool")
self.options["IGNORE_UTF8_ERRORS"] = ignore_utf8_errors
return self


class PARQUETFormatter(CopyFormatter):
class PARQUETFormatter(CopyFormatter, NullIfMixin, TrimSpaceMixin):
"""Format specific functions"""

file_format = 'parquet'
Expand Down