Skip to content

Commit

Permalink
Fixes default data quality pd series type check
Browse files Browse the repository at this point in the history
The validator was not handling the newer pandas extension dtypes (e.g. pd.Float64Dtype).
This seems to fix that, while also updating the example
so that it no longer logs warnings.
  • Loading branch information
skrawcz committed Nov 7, 2023
1 parent 4d981fa commit e4ec0bb
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
18 changes: 9 additions & 9 deletions examples/data_quality/simple/feature_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,25 +82,25 @@ def seasons_encoded__dask(seasons: pd.Series) -> pd.DataFrame:
return df


@check_output(data_type=np.uint8, values_in=[0, 1], allow_nans=False)
@check_output(data_type=np.bool_, values_in=[0, 1], allow_nans=False)
def seasons_1(seasons_encoded: pd.DataFrame) -> pd.Series:
"""Returns column seasons_1"""
return seasons_encoded["seasons_1"]


@check_output(data_type=np.uint8, values_in=[0, 1], allow_nans=False)
@check_output(data_type=np.bool_, values_in=[0, 1], allow_nans=False)
def seasons_2(seasons_encoded: pd.DataFrame) -> pd.Series:
"""Returns column seasons_2"""
return seasons_encoded["seasons_2"]


@check_output(data_type=np.uint8, values_in=[0, 1], allow_nans=False)
@check_output(data_type=np.bool_, values_in=[0, 1], allow_nans=False)
def seasons_3(seasons_encoded: pd.DataFrame) -> pd.Series:
"""Returns column seasons_3"""
return seasons_encoded["seasons_3"]


@check_output(data_type=np.uint8, values_in=[0, 1], allow_nans=False)
@check_output(data_type=np.bool_, values_in=[0, 1], allow_nans=False)
def seasons_4(seasons_encoded: pd.DataFrame) -> pd.Series:
"""Returns column seasons_4"""
return seasons_encoded["seasons_4"]
Expand Down Expand Up @@ -136,31 +136,31 @@ def day_of_week_encoded__dask(day_of_the_week: pd.Series) -> pd.DataFrame:
return df


@check_output(data_type=np.uint8, values_in=[0, 1], allow_nans=False)
@check_output(data_type=np.bool_, values_in=[0, 1], allow_nans=False)
def day_of_the_week_2(day_of_week_encoded: pd.DataFrame) -> pd.Series:
"""Pulls out the day_of_the_week_2 column."""
return day_of_week_encoded["day_of_the_week_2"]


@check_output(data_type=np.uint8, values_in=[0, 1], allow_nans=False)
@check_output(data_type=np.bool_, values_in=[0, 1], allow_nans=False)
def day_of_the_week_3(day_of_week_encoded: pd.DataFrame) -> pd.Series:
"""Pulls out the day_of_the_week_3 column."""
return day_of_week_encoded["day_of_the_week_3"]


@check_output(data_type=np.uint8, values_in=[0, 1], allow_nans=False)
@check_output(data_type=np.bool_, values_in=[0, 1], allow_nans=False)
def day_of_the_week_4(day_of_week_encoded: pd.DataFrame) -> pd.Series:
"""Pulls out the day_of_the_week_4 column."""
return day_of_week_encoded["day_of_the_week_4"]


@check_output(data_type=np.uint8, values_in=[0, 1], allow_nans=False)
@check_output(data_type=np.bool_, values_in=[0, 1], allow_nans=False)
def day_of_the_week_5(day_of_week_encoded: pd.DataFrame) -> pd.Series:
"""Pulls out the day_of_the_week_5 column."""
return day_of_week_encoded["day_of_the_week_5"]


@check_output(data_type=np.uint8, values_in=[0, 1], allow_nans=False)
@check_output(data_type=np.bool_, values_in=[0, 1], allow_nans=False)
def day_of_the_week_6(day_of_week_encoded: pd.DataFrame) -> pd.Series:
"""Pulls out the day_of_the_week_6 column."""
return day_of_week_encoded["day_of_the_week_6"]
Expand Down
2 changes: 2 additions & 0 deletions hamilton/data_quality/default_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ def description(self) -> str:

def validate(self, data: pd.Series) -> base.ValidationResult:
dtype = data.dtype
if hasattr(dtype, "type"):
dtype = dtype.type
passes = np.issubdtype(dtype, self.datatype)
return base.ValidationResult(
passes=passes,
Expand Down
6 changes: 6 additions & 0 deletions tests/test_default_data_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ def test_resolve_default_validators_error(output_type, kwargs, importance):
pd.Series(["hello", "goodbye"]),
True,
),
(
default_validators.DataTypeValidatorPandasSeries,
np.float64,
pd.Series([2.3, 4.5, 6.6], dtype=pd.Float64Dtype()),
True,
),
(
default_validators.DataTypeValidatorPandasSeries,
numpy.dtype("object"),
Expand Down

0 comments on commit e4ec0bb

Please sign in to comment.