diff --git a/quinn/dataframe_validator.py b/quinn/dataframe_validator.py index afd21c5a..386979b7 100644 --- a/quinn/dataframe_validator.py +++ b/quinn/dataframe_validator.py @@ -37,7 +37,28 @@ def validate_presence_of_columns(df: DataFrame, required_col_names: list[str]) - if missing_col_names: raise DataFrameMissingColumnError(error_message) -def validate_schema(required_schema: StructType, ignore_nullable=False, _func=None): +def validate_schema( + required_schema: StructType, + ignore_nullable: bool = False, + _df: DataFrame = None +) -> function: + """Function that validates if a given DataFrame has a given StructType as its schema. + Implemented as a decorator factory so it can be used either as a standalone function or as + a decorator on another function. + + :param required_schema: StructType required for the DataFrame + :type required_schema: StructType + :param ignore_nullable: (Optional) A flag for whether nullable fields should be + ignored during validation + :type ignore_nullable: bool, optional + :param _df: DataFrame to validate, mandatory when called as a function. Not required + when called as a decorator + :type _df: DataFrame + + :raises DataFrameMissingStructFieldError: if any StructFields from the required + schema are not included in the DataFrame schema + """ + def decorator(func): def wrapper(*args, **kwargs): df = func(*args, **kwargs) @@ -59,47 +80,12 @@ def wrapper(*args, **kwargs): return df return wrapper - if _func is None: + if _df is None: # This means the function is being used as a decorator return decorator else: # This means the function is being called directly with a DataFrame - return decorator(lambda: _func)() - - -def x_validate_schema( - df: DataFrame, - required_schema: StructType, - ignore_nullable: bool = False, -) -> None: - """Function that validate if a given DataFrame has a given StructType as its schema. 
- - :param df: DataFrame to validate - :type df: DataFrame - :param required_schema: StructType required for the DataFrame - :type required_schema: StructType - :param ignore_nullable: (Optional) A flag for if nullable fields should be - ignored during validation - :type ignore_nullable: bool, optional - - :raises DataFrameMissingStructFieldError: if any StructFields from the required - schema are not included in the DataFrame schema - """ - _all_struct_fields = copy.deepcopy(df.schema) - _required_schema = copy.deepcopy(required_schema) - - if ignore_nullable: - for x in _all_struct_fields: - x.nullable = None - - for x in _required_schema: - x.nullable = None - - missing_struct_fields = [x for x in _required_schema if x not in _all_struct_fields] - error_message = f"The {missing_struct_fields} StructFields are not included in the DataFrame with the following StructFields {_all_struct_fields}" - - if missing_struct_fields: - raise DataFrameMissingStructFieldError(error_message) + return decorator(lambda: _df)() def validate_absence_of_columns(df: DataFrame, prohibited_col_names: list[str]) -> None: