Skip to content

Commit

Permalink
cleanup some scheduler rewrites [run_process_replay] (tinygrad#6474)
Browse files Browse the repository at this point in the history
  • Loading branch information
Qazalin authored Sep 11, 2024
1 parent 1caddde commit d6d9234
Showing 1 changed file with 8 additions and 13 deletions.
21 changes: 8 additions & 13 deletions tinygrad/engine/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,27 +108,22 @@ def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTr
tmp = input_st.permute(permute_axis)
return tmp, tmp.shape[-len(axis):]

def swizzle_reduceop(input_st:ShapeTracker, swizzle:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[int, ...]]:
# push the movementop to the buffer uop
tmp, rshape = permute_reduce(input_st, axis)
# ***** reduceop fusor *****

def push_swizzle_up_through_reduce(swizzle:UOp, reduceop:UOp) -> Optional[UOp]:
if swizzle.arg.contiguous: return None
rsrc = reduceop.src[0]
tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(rsrc.st).shape), reduceop.arg[1])
prshape = prod(rshape)
strides = strides_for_shape(rshape)
nv: List[View] = []
for v in swizzle.views:
for v in swizzle.arg.views:
nv.append(View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None))
# update input_st and axis
new_input_st = tmp + ShapeTracker(tuple(nv))
_, new_rshape = permute_reduce(new_input_st, axis)
_, new_rshape = permute_reduce(new_input_st, reduceop.arg[1])
new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
return new_input_st, new_axis

# ***** reduceop fusor *****

def push_swizzle_up_through_reduce(swizzle:UOp, reduceop:UOp) -> Optional[UOp]:
if swizzle.arg.contiguous: return None
rsrc = reduceop.src[0]
new_input_st, new_axis = swizzle_reduceop(ShapeTracker.from_shape(unwrap(rsrc.st).shape), swizzle.arg, reduceop.arg[1])
return UOp(UOps.SWIZZLE, reduceop.dtype, (UOp(UOps.REDUCE_AXIS, reduceop.dtype, (st_fixup(rsrc, lambda st:st+new_input_st, {}),),
(reduceop.arg[0], new_axis)),), ShapeTracker.from_shape(swizzle.arg.shape))

Expand Down

0 comments on commit d6d9234

Please sign in to comment.