diff --git a/quinn/dataframe_validator.py b/quinn/dataframe_validator.py index c4b20a64..8c667803 100644 --- a/quinn/dataframe_validator.py +++ b/quinn/dataframe_validator.py @@ -1,7 +1,7 @@ -from __future__ import annotations +from __future__ import annotations # noqa: I001 import copy -from typing import TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING if TYPE_CHECKING: from pyspark.sql import DataFrame @@ -36,14 +36,14 @@ def validate_presence_of_columns(df: DataFrame, required_col_names: list[str]) - error_message = f"The {missing_col_names} columns are not included in the DataFrame with the following columns {all_col_names}" if missing_col_names: raise DataFrameMissingColumnError(error_message) - + def validate_schema( - required_schema: StructType, - ignore_nullable: bool = False, - _df: DataFrame = None -) -> function: + required_schema: StructType, + ignore_nullable: bool = False, + _df: DataFrame = None, +) -> Callable[[Any, Any], Any]: """Function that validate if a given DataFrame has a given StructType as its schema. - Implemented as a decorator factory so can be used both as a standalone function or as + Implemented as a decorator factory so can be used both as a standalone function or as a decorator to another function. :param required_schema: StructType required for the DataFrame @@ -59,10 +59,10 @@ def validate_schema( schema are not included in the DataFrame schema """ - def decorator(func): - def wrapper(*args, **kwargs): - df = func(*args, **kwargs) - _all_struct_fields = copy.deepcopy(df.schema) + def decorator(func: Callable[..., DataFrame]) -> Callable[..., DataFrame]: + def wrapper(*args: object, **kwargs: object) -> DataFrame: + dataframe = func(*args, **kwargs) + _all_struct_fields = copy.deepcopy(dataframe.schema) _required_schema = copy.deepcopy(required_schema) if ignore_nullable: @@ -73,22 +73,22 @@ def wrapper(*args, **kwargs): x.nullable = None missing_struct_fields = [x for x in _required_schema if x not in _all_struct_fields] - error_message = f"The {missing_struct_fields} StructFields are not included in the DataFrame with the following StructFields {_all_struct_fields}" + error_message = f"The {missing_struct_fields} StructFields are not included in the DataFrame with the following StructFields {_all_struct_fields}" # noqa: E501 if missing_struct_fields: raise DataFrameMissingStructFieldError(error_message) - else: - print("Success! DataFrame matches the required schema!") - return df + print("Success! DataFrame matches the required schema!") + + return dataframe return wrapper if _df is None: # This means the function is being used as a decorator return decorator - else: - # This means the function is being called directly with a DataFrame - return decorator(lambda: _df)() + + # This means the function is being called directly with a DataFrame + return decorator(lambda: _df)() def validate_absence_of_columns(df: DataFrame, prohibited_col_names: list[str]) -> None: