Skip to content

Commit

Permalink
green conv bw AST_REWRITE=1 (tinygrad#6466)
Browse files Browse the repository at this point in the history
* green conv bw AST_REWRITE=1

* new strides and dtype fix
  • Loading branch information
Qazalin authored Sep 11, 2024
1 parent 15c4d4f commit 262569a
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
5 changes: 2 additions & 3 deletions test/test_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,7 +1289,6 @@ def test_reduceop_reshape_dont_push(self):

def test_conv2d(self): _test_conv2d(8)
def test_conv2d_fused(self): _test_conv2d(7, FUSE_CONV_BW=1)
@unittest.expectedFailure
def test_conv2d_fused_ast_rewrite(self): _test_conv2d(7, FUSE_CONV_BW=1, AST_REWRITE=1)

@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
Expand All @@ -1299,7 +1298,7 @@ def test_conv2d_half(self): _test_conv2d(8, dtype=dtypes.half)
def test_conv2d_fused_half(self): _test_conv2d(7, dtype=dtypes.half)
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
@unittest.expectedFailure
def test_conv2d_fused_ast_rewrite_half(self): _test_conv2d(7, FUSE_CONV_BW=1, AST_REWRITE=1)
def test_conv2d_fused_ast_rewrite_half(self): _test_conv2d(7, FUSE_CONV_BW=1, AST_REWRITE=1, dtype=dtypes.half)

class TestIndexing(unittest.TestCase):
def check_schedule(self, xt:Union[Tensor,List[Tensor]], cnt:int):
Expand Down Expand Up @@ -1706,7 +1705,7 @@ def test_swizzle_rewrite_alt(self):
# and pushed to the LOAD
new_load_st = unwrap([x for x in ret.parents if x.op is UOps.SHAPETRACKER][0].st)
self.assertGreater(prod(new_load_st.shape), prod(ld_st.shape))
self.assertEqual(new_load_st.views[0].strides, (0, 0, 3, 0, 1, 0, 0))
self.assertEqual(new_load_st.views[0].strides, (0, 9, 3, 0, 1, 0, 27))

if __name__ == '__main__':
unittest.main(verbosity=2)
2 changes: 1 addition & 1 deletion tinygrad/engine/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def swizzle_reduceop(input_st:ShapeTracker, swizzle:ShapeTracker, axis:Tuple[int
def push_swizzle_up_through_reduce(swizzle:UOp, reduceop:UOp) -> Optional[UOp]:
if swizzle.arg.contiguous: return None
rsrc = reduceop.src[0]
new_input_st, new_axis = swizzle_reduceop(unwrap(rsrc.st), swizzle.arg, reduceop.arg[1])
new_input_st, new_axis = swizzle_reduceop(ShapeTracker.from_shape(unwrap(rsrc.st).shape), swizzle.arg, reduceop.arg[1])
return UOp(UOps.SWIZZLE, reduceop.dtype, (UOp(UOps.REDUCE_AXIS, reduceop.dtype, (st_fixup(rsrc, lambda st:st+new_input_st, {}),),
(reduceop.arg[0], new_axis)),), ShapeTracker.from_shape(swizzle.arg.shape))

Expand Down

0 comments on commit 262569a

Please sign in to comment.