diff --git a/pandas_schema/validation.py b/pandas_schema/validation.py index 2334bb6..2a3f2f8 100644 --- a/pandas_schema/validation.py +++ b/pandas_schema/validation.py @@ -9,6 +9,7 @@ from . import column from .validation_warning import ValidationWarning from .errors import PanSchArgumentError +from pandas.api.types import is_categorical_dtype, is_numeric_dtype class _BaseValidation: @@ -84,10 +85,12 @@ def get_errors(self, series: pd.Series, column: 'column.Column'): simple_validation = ~self.validate(series) if column.allow_empty: # Failing results are those that are not empty, and fail the validation - if np.issubdtype(series.dtype, np.number): - validated = ~series.isna() & simple_validation + # explicitly check to make sure the series isn't a category because issubdtype will FAIL if it is + if is_categorical_dtype(series) or is_numeric_dtype(series): + validated = ~series.isnull() & simple_validation else: validated = (series.str.len() > 0) & simple_validation + else: validated = simple_validation diff --git a/setup.py b/setup.py index bf2cbc0..0f68cdf 100755 --- a/setup.py +++ b/setup.py @@ -78,7 +78,7 @@ def run(self): ], keywords='pandas csv verification schema', packages=find_packages(include=['pandas_schema']), - install_requires=['numpy', 'pandas'], + install_requires=['numpy', 'pandas>=0.19'], cmdclass={ 'build_readme': BuildReadme, 'build_site': BuildHtmlDocs diff --git a/test/test_validation.py b/test/test_validation.py index 3e8f2bb..7914025 100644 --- a/test/test_validation.py +++ b/test/test_validation.py @@ -15,9 +15,10 @@ def seriesEquality(self, s1: pd.Series, s2: pd.Series, msg: str = None): if not s1.equals(s2): raise self.failureException(msg) - def validate_and_compare(self, series: list, expected_result: bool, msg: str = None): + def validate_and_compare(self, series: list, expected_result: bool, msg: str = None, series_dtype: object = None): """ Checks that every element in the provided series is equal to `expected_result` after 
validation + :param series_dtype: Explicitly specifies the dtype for the generated Series :param series: The series to check :param expected_result: Whether the elements in this series should pass the validation :param msg: The message to display if this test fails @@ -31,7 +32,7 @@ def validate_and_compare(self, series: list, expected_result: bool, msg: str = N self.addTypeEqualityFunc(pd.Series, self.seriesEquality) # Convert the input list to a series and validate it - results = self.validator.validate(pd.Series(series)) + results = self.validator.validate(pd.Series(series, dtype=series_dtype)) # Now find any items where their validation does not correspond to the expected_result for item, result in zip(series, results): @@ -639,3 +640,32 @@ def test_in_range_allow_empty_false_with_error(self): validator = InRangeValidation(min=4) errors = validator.get_errors(pd.Series(self.vals), Column('', allow_empty=False)) self.assertEqual(len(errors), len(self.vals)) + + +class PandasDtypeTests(ValidationTestBase): + """ + Tests Series with various pandas dtypes that don't exist in numpy (specifically categories) + """ + + def setUp(self): + self.validator = InListValidation(['a', 'b', 'c'], case_sensitive=False) + + def test_valid_elements(self): + errors = self.validator.get_errors(pd.Series(['a', 'b', 'c', None, 'A', 'B', 'C'], dtype='category'), + Column('', allow_empty=True)) + self.assertEqual(len(errors), 0) + + def test_invalid_empty_elements(self): + errors = self.validator.get_errors(pd.Series(['aa', 'bb', 'd', None], dtype='category'), + Column('', allow_empty=False)) + self.assertEqual(len(errors), 4) + + def test_invalid_and_empty_elements(self): + errors = self.validator.get_errors(pd.Series(['a', None], dtype='category'), + Column('', allow_empty=False)) + self.assertEqual(len(errors), 1) + + def test_invalid_elements(self): + errors = self.validator.get_errors(pd.Series(['aa', 'bb', 'd'], dtype='category'), + Column('', allow_empty=True)) + 
self.assertEqual(len(errors), 3)