From e4ec0bb7cdf88957be1ccbb73ee6841f93092ae6 Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Mon, 6 Nov 2023 17:49:52 -0800 Subject: [PATCH] Fixes default data quality pd series type check This was not handling the new pandas types. This seems to fix that, while also updating the example to no log warnings. --- examples/data_quality/simple/feature_logic.py | 18 +++++++++--------- hamilton/data_quality/default_validators.py | 2 ++ tests/test_default_data_quality.py | 6 ++++++ 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/examples/data_quality/simple/feature_logic.py b/examples/data_quality/simple/feature_logic.py index fd108eccc..9f3e92b57 100644 --- a/examples/data_quality/simple/feature_logic.py +++ b/examples/data_quality/simple/feature_logic.py @@ -82,25 +82,25 @@ def seasons_encoded__dask(seasons: pd.Series) -> pd.DataFrame: return df -@check_output(data_type=np.uint8, values_in=[0, 1], allow_nans=False) +@check_output(data_type=np.bool_, values_in=[0, 1], allow_nans=False) def seasons_1(seasons_encoded: pd.DataFrame) -> pd.Series: """Returns column seasons_1""" return seasons_encoded["seasons_1"] -@check_output(data_type=np.uint8, values_in=[0, 1], allow_nans=False) +@check_output(data_type=np.bool_, values_in=[0, 1], allow_nans=False) def seasons_2(seasons_encoded: pd.DataFrame) -> pd.Series: """Returns column seasons_2""" return seasons_encoded["seasons_2"] -@check_output(data_type=np.uint8, values_in=[0, 1], allow_nans=False) +@check_output(data_type=np.bool_, values_in=[0, 1], allow_nans=False) def seasons_3(seasons_encoded: pd.DataFrame) -> pd.Series: """Returns column seasons_3""" return seasons_encoded["seasons_3"] -@check_output(data_type=np.uint8, values_in=[0, 1], allow_nans=False) +@check_output(data_type=np.bool_, values_in=[0, 1], allow_nans=False) def seasons_4(seasons_encoded: pd.DataFrame) -> pd.Series: """Returns column seasons_4""" return seasons_encoded["seasons_4"] @@ -136,31 +136,31 @@ def day_of_week_encoded__dask(day_of_the_week: pd.Series) -> pd.DataFrame: return df -@check_output(data_type=np.uint8, values_in=[0, 1], allow_nans=False) +@check_output(data_type=np.bool_, values_in=[0, 1], allow_nans=False) def day_of_the_week_2(day_of_week_encoded: pd.DataFrame) -> pd.Series: """Pulls out the day_of_the_week_2 column.""" return day_of_week_encoded["day_of_the_week_2"] -@check_output(data_type=np.uint8, values_in=[0, 1], allow_nans=False) +@check_output(data_type=np.bool_, values_in=[0, 1], allow_nans=False) def day_of_the_week_3(day_of_week_encoded: pd.DataFrame) -> pd.Series: """Pulls out the day_of_the_week_3 column.""" return day_of_week_encoded["day_of_the_week_3"] -@check_output(data_type=np.uint8, values_in=[0, 1], allow_nans=False) +@check_output(data_type=np.bool_, values_in=[0, 1], allow_nans=False) def day_of_the_week_4(day_of_week_encoded: pd.DataFrame) -> pd.Series: """Pulls out the day_of_the_week_4 column.""" return day_of_week_encoded["day_of_the_week_4"] -@check_output(data_type=np.uint8, values_in=[0, 1], allow_nans=False) +@check_output(data_type=np.bool_, values_in=[0, 1], allow_nans=False) def day_of_the_week_5(day_of_week_encoded: pd.DataFrame) -> pd.Series: """Pulls out the day_of_the_week_5 column.""" return day_of_week_encoded["day_of_the_week_5"] -@check_output(data_type=np.uint8, values_in=[0, 1], allow_nans=False) +@check_output(data_type=np.bool_, values_in=[0, 1], allow_nans=False) def day_of_the_week_6(day_of_week_encoded: pd.DataFrame) -> pd.Series: """Pulls out the day_of_the_week_6 column.""" return day_of_week_encoded["day_of_the_week_6"] diff --git a/hamilton/data_quality/default_validators.py b/hamilton/data_quality/default_validators.py index 23196a0d9..f77cf2cef 100644 --- a/hamilton/data_quality/default_validators.py +++ b/hamilton/data_quality/default_validators.py @@ -259,6 +259,8 @@ def description(self) -> str: def validate(self, data: pd.Series) -> base.ValidationResult: dtype = data.dtype + if hasattr(dtype, "type"): + dtype = dtype.type passes = np.issubdtype(dtype, self.datatype) return base.ValidationResult( passes=passes, diff --git a/tests/test_default_data_quality.py b/tests/test_default_data_quality.py index 981a293a7..3df2414a3 100644 --- a/tests/test_default_data_quality.py +++ b/tests/test_default_data_quality.py @@ -171,6 +171,12 @@ def test_resolve_default_validators_error(output_type, kwargs, importance): pd.Series(["hello", "goodbye"]), True, ), + ( + default_validators.DataTypeValidatorPandasSeries, + np.float64, + pd.Series([2.3, 4.5, 6.6], dtype=pd.Float64Dtype()), + True, + ), ( default_validators.DataTypeValidatorPandasSeries, numpy.dtype("object"),