Skip to content

Commit

Permalink
Fixes pandas string data validation
Browse files Browse the repository at this point in the history
  • Loading branch information
erp12 committed Jul 20, 2020
1 parent d61d2df commit 24d44b0
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 22 deletions.
2 changes: 1 addition & 1 deletion pyshgp/push/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __init__(self):
class PushStrType(PushType):

def __init__(self):
super().__init__("str", (str, np.str_))
super().__init__("str", (str, np.str_, np.object_))


class PushVectorType(PushType):
Expand Down
38 changes: 17 additions & 21 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,31 +40,27 @@ def test_check_2d_on_3d():
check_2d(np.arange(12).reshape(-1, 2, 2))


def test_check_column_types_on_int_ndarray():
assert (
check_column_types(np.arange(30).reshape(-1, 3)) == [np.int64, np.int64, np.int64] or
check_column_types(np.arange(30).reshape(-1, 3)) == [np.int32, np.int32, np.int32] or
check_column_types(np.arange(30).reshape(-1, 3)) == [np.int16, np.int16, np.int16]
)
def test_check_column_types():
arr_col_types = check_column_types(np.arange(30).reshape(-1, 3))
assert arr_col_types == [np.int64, np.int64, np.int64] or \
arr_col_types == [np.int32, np.int32, np.int32] or \
arr_col_types == [np.int16, np.int16, np.int16]


def test_check_column_types_on_nested_list():
mock_dataset = [
[1, "a"],
[2, "b"],
[3, "c"]
]
mock_dataset = [[1, "a"], [2, "b"], [3, "c"]]
assert check_column_types(mock_dataset, 1.0) == [int, str]


def test_check_column_types_on_bad_dataset():
mock_dataset = [
[1, "a"],
[2, False],
[3, "c"]
]
mock_dataset2 = [[1, "a"], [2, False], [3, "c"]]
with pytest.raises(ValueError):
check_column_types(mock_dataset, 1.0)
check_column_types(mock_dataset2, 1.0)

df = pd.DataFrame({
"i": [1, 2, 3, 4, 5],
"s": ["a", "b", "c", "d", "e"]
})
df_col_types = check_column_types(df, 1.0)
assert df_col_types == [np.int64, np.object_] or \
df_col_types == [np.int32, np.object_] or \
df_col_types == [np.int16, np.object_]


def test_check_num_columns_a():
Expand Down

0 comments on commit 24d44b0

Please sign in to comment.