Skip to content

Commit

Permalink
Merge pull request #409 from rsokl/fix-failing-test
Browse files Browse the repository at this point in the history
remove bad use of wraps in test
  • Loading branch information
rsokl authored Jul 6, 2022
2 parents 24531e1 + d389927 commit 79b55fb
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions tests/tensor_ops/test_setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,7 @@ def __init__(
def __call__(self, f: Callable) -> Callable[[], None]:
"""Wraps an empty function to populate it with the test function"""

@given(data=st.data(), x=self.array_strat)
@wraps(f)
@given(x=self.array_strat, data=st.data())
def wrapper(x: np.ndarray, data: st.DataObject):

index = data.draw(self.index_strat(x), label="index")
Expand Down Expand Up @@ -204,7 +203,7 @@ def test_setitem_multiple_input():

@given(x_constant=st.booleans(), y_constant=st.booleans(), data=st.data())
def test_setitem_sanity_check(x_constant, y_constant, data):
""" Ensure proper setitem behavior for all combinations of constant/variable Tensors"""
"""Ensure proper setitem behavior for all combinations of constant/variable Tensors"""
x = Tensor([1.0, 2.0, 3.0, 4.0], constant=x_constant)
w = 4 * x

Expand Down Expand Up @@ -261,7 +260,7 @@ def test_setitem_downstream_doesnt_affect_upstream_backprop():
@pytest.mark.parametrize("x_constant", [True, False])
@pytest.mark.parametrize("y_constant", [True, False])
def test_setitem_doesnt_mutate_upstream_nodes(x_constant: bool, y_constant: bool):
""" Ensure setitem doesn't mutate variable non-constant tensor"""
"""Ensure setitem doesn't mutate variable non-constant tensor"""
x = Tensor([1.0, 2.0], constant=x_constant)
y = Tensor([3.0, 4.0], constant=y_constant)
z = x + y
Expand Down Expand Up @@ -393,7 +392,7 @@ def test_setitem_mixed_index():
),
)
def test_setitem_broadcast_bool_index():
""" index mixes boolean and int-array indexing"""
"""index mixes boolean and int-array indexing"""


@settings(deadline=None)
Expand All @@ -409,7 +408,7 @@ def test_setitem_broadcast_bool_index():
),
)
def test_setitem_bool_basic_index():
""" index mixes boolean and basic indexing"""
"""index mixes boolean and basic indexing"""


@settings(deadline=None)
Expand All @@ -427,7 +426,7 @@ def test_setitem_bool_basic_index():
else st.floats(-10.0, 10.0).map(np.asarray),
)
def test_setitem_bool_axes_index():
""" index consists of boolean arrays specified for each axis """
"""index consists of boolean arrays specified for each axis"""


@settings(deadline=None, max_examples=1000)
Expand Down Expand Up @@ -456,4 +455,4 @@ def test_setitem_bool_axes_index():
),
)
def test_setitem_arbitrary_index():
""" test arbitrary indices"""
"""test arbitrary indices"""

0 comments on commit 79b55fb

Please sign in to comment.