Skip to content

Commit

Permalink
Static type error fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
kunaljubce committed Jul 16, 2024
1 parent b9a6d08 commit 21d87e5
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions quinn/dataframe_validator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations
from __future__ import annotations # noqa: I001

import copy
from typing import TYPE_CHECKING
from typing import Any, Callable, TYPE_CHECKING

if TYPE_CHECKING:
from pyspark.sql import DataFrame
Expand Down Expand Up @@ -36,14 +36,14 @@ def validate_presence_of_columns(df: DataFrame, required_col_names: list[str]) -
error_message = f"The {missing_col_names} columns are not included in the DataFrame with the following columns {all_col_names}"
if missing_col_names:
raise DataFrameMissingColumnError(error_message)

def validate_schema(
required_schema: StructType,
ignore_nullable: bool = False,
_df: DataFrame = None
) -> function:
required_schema: StructType,
ignore_nullable: bool = False,
_df: DataFrame = None,
) -> Callable[[Any, Any], Any]:
"""Function that validate if a given DataFrame has a given StructType as its schema.
Implemented as a decorator factory so can be used both as a standalone function or as
Implemented as a decorator factory so can be used both as a standalone function or as
a decorator to another function.
:param required_schema: StructType required for the DataFrame
Expand All @@ -59,10 +59,10 @@ def validate_schema(
schema are not included in the DataFrame schema
"""

def decorator(func):
def wrapper(*args, **kwargs):
df = func(*args, **kwargs)
_all_struct_fields = copy.deepcopy(df.schema)
def decorator(func: Callable[..., DataFrame]) -> Callable[..., DataFrame]:
def wrapper(*args: object, **kwargs: object) -> DataFrame:
dataframe = func(*args, **kwargs)
_all_struct_fields = copy.deepcopy(dataframe.schema)
_required_schema = copy.deepcopy(required_schema)

if ignore_nullable:
Expand All @@ -73,22 +73,22 @@ def wrapper(*args, **kwargs):
x.nullable = None

missing_struct_fields = [x for x in _required_schema if x not in _all_struct_fields]
error_message = f"The {missing_struct_fields} StructFields are not included in the DataFrame with the following StructFields {_all_struct_fields}"
error_message = f"The {missing_struct_fields} StructFields are not included in the DataFrame with the following StructFields {_all_struct_fields}" # noqa: E501

if missing_struct_fields:
raise DataFrameMissingStructFieldError(error_message)
else:
print("Success! DataFrame matches the required schema!")

return df
print("Success! DataFrame matches the required schema!")

return dataframe
return wrapper

if _df is None:
# This means the function is being used as a decorator
return decorator
else:
# This means the function is being called directly with a DataFrame
return decorator(lambda: _df)()

# This means the function is being called directly with a DataFrame
return decorator(lambda: _df)()


def validate_absence_of_columns(df: DataFrame, prohibited_col_names: list[str]) -> None:
Expand Down

0 comments on commit 21d87e5

Please sign in to comment.