Skip to content

Commit

Permalink
enable graph rewrite in the scheduler (tinygrad#6249)
Browse files Browse the repository at this point in the history
* test: enable

* skip those

* skip pads tests
  • Loading branch information
Qazalin authored Sep 11, 2024
1 parent d9d1ae7 commit 3cde150
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 65 deletions.
5 changes: 4 additions & 1 deletion test/test_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1317,6 +1317,7 @@ def test_simple_indexing(self):
self.check_schedule(xt, 2)
np.testing.assert_equal(xt.numpy(), X.numpy()[idxs.numpy()])

@unittest.skip("TODO: support pads in graph_rewrite")
def test_simple_indexing_alt(self):
X = Tensor.arange(16).reshape(4, 4)
xt = X[[1, 2], [1, 2]]
Expand All @@ -1337,6 +1338,7 @@ def test_advanced_indexing_alt(self):
self.check_schedule(xt, 6)
np.testing.assert_equal(xt.numpy(), 6)

@unittest.skip("TODO: support pads in graph_rewrite")
def test_advanced_simple_indexing_combined(self):
X = Tensor.arange(16).reshape(4, 4)
xt = X[1:2, [1, 2]]
Expand Down Expand Up @@ -1468,7 +1470,8 @@ def test_arange_expand_copy(self):
self.check_schedule(a, 1)
np.testing.assert_equal(a.numpy(), [[[0, 0], [1, 1]], [[2, 2], [3, 3]]])

@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
@unittest.skip("TODO: support pads in graph_rewrite")
#@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_precompute_freqs_cis(self):
args = {"dim":32 if CI else 128, "end":2048 if CI else 8192, "theta":10000, "dtype":dtypes.half}
fused = precompute_freqs_cis(**args)
Expand Down
70 changes: 7 additions & 63 deletions tinygrad/engine/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def __hash__(self):

def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], inputs:Dict[LazyBuffer, int],
realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer],
reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]],
cache:Dict[Tuple[LazyBuffer, ShapeTracker], UOp]) -> UOp:
"""recursively create a UOp"""
if buf is not buf.base: st, buf = buf.st+st, buf.base
Expand Down Expand Up @@ -79,58 +78,19 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..

# reduce ops change ShapeTracker
if buf.op in ReduceOps:
alu_op = REDUCE_ALU[cast(ReduceOps, buf.op)]
if not AST_REWRITE:
rinfo = reduce_info.get((buf, st))
rsrc = _recursive_uop(buf.srcs[0], st:=(rinfo[0] if rinfo else st), outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache)
# if we are merging the reduce, skip it
if rinfo is None:
assert rsrc.op is UOps.REDUCE_AXIS and rsrc.arg[0] is alu_op, f"can't merge reduceop {buf.op} with {rsrc}\n{st}"
return rsrc
return cache.setdefault((buf, st), UOp(UOps.REDUCE_AXIS, dtype, (rsrc,), (alu_op, rinfo[1])))
# this is the new reduceop swizzler with graph_rewrite
input_st = ShapeTracker.from_shape(buf.srcs[0].shape)
rsrc = _recursive_uop(buf.srcs[0], input_st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache)
ret = UOp(UOps.REDUCE_AXIS, dtype, (rsrc,), (alu_op, buf.arg))
rsrc = _recursive_uop(buf.srcs[0], ShapeTracker.from_shape(buf.srcs[0].shape), outputs, var_vals, inputs, realizes, assign_targets, cache)
ret = UOp(UOps.REDUCE_AXIS, dtype, (rsrc,), (REDUCE_ALU[cast(ReduceOps, buf.op)], buf.arg))
return cache.setdefault((buf, st), UOp(UOps.SWIZZLE, dtype, (ret,), st))

# elementwise ops pass shapetracker
in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache) for x in buf.srcs)
in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, realizes, assign_targets, cache) for x in buf.srcs)
if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}:
assert buf in outputs, f"{buf.op} must be writable"
return in_uops[0]
if buf.op is UnaryOps.CAST: return cache.setdefault((buf, st), UOp(UOps.CAST, dtype, in_uops))
if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_uops))
return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_uops, buf.op))

def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer],
reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]],
cache:Dict[Tuple[LazyBuffer, ShapeTracker], Optional[Tuple[LazyBuffer, ShapeTracker]]]) -> \
Optional[Tuple[LazyBuffer, ShapeTracker]]:
if (buf, st) in cache: return cache[(buf, st)]
if buf.base.realized is not None or (buf.base in realizes and buf.base not in outs): return None
if buf is not buf.base: st, buf = buf.st+st, buf.base
input_st = ShapeTracker.from_shape(buf.srcs[0].shape) if buf.op in ReduceOps else st
reduce_srcs = [r for x in buf.srcs if (r:=_recurse_reduceops(x, input_st, realizes, outs, reduce_info, cache)) is not None]
top_reduce = reduce_srcs[-1] if len(reduce_srcs) != 0 else None
if buf.op in ReduceOps:
axis = buf.arg
if not st.contiguous: input_st, axis = swizzle_reduceop(input_st, st, axis)
elif top_reduce is not None:
top_reduce_input_st, top_reduce_axes = reduce_info[top_reduce]
if buf.srcs[0] is not buf.srcs[0].base and buf.srcs[0].base is top_reduce[0] and buf.op is top_reduce[0].op:
# merge this reduce with its parent
new_st = top_reduce[1]+st
top_reduce = (top_reduce[0], new_st.reshape(top_reduce_input_st.reduce(new_axis:=axis+top_reduce_axes)))
reduce_info[top_reduce] = (top_reduce_input_st, new_axis)
return None
# reshape this reduceop based on the top reduce
input_st = input_st.reshape(tuple(1 if i in top_reduce_axes else s for i,s in enumerate(top_reduce_input_st.shape)))
st = st.reshape(input_st.reduce(axis))
reduce_info[(buf, st)] = (input_st, axis)
return (buf, st)
return cache.setdefault((buf, st), top_reduce)

# ***** helpers for doing movementops on uops *****

def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], cache:Dict[UOp, UOp]) -> UOp:
Expand Down Expand Up @@ -213,40 +173,24 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
return [LBScheduleItem(UOp(UOps.SINK, None, (wr,)), outs, [x.base for x in out.srcs])]
if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
return [LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), outs, [x.base for x in out.srcs])]
reduce_info: Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]] = {}
if not AST_REWRITE:
# push through all movementops between reduceops
# NOTE: AST_REWRITE does this with graph rewrite
seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], Optional[Tuple[LazyBuffer, ShapeTracker]]] = {}
for out in outs: _recurse_reduceops(out, out.st, realizes, outs, reduce_info, seen_ops)
# pad all reduceops to the max of each dimension
shape_dims = [sorted(dedup(dims)) for dims in zip(*[input_st.shape for input_st,_ in reduce_info.values()])]
for i,dims in enumerate(shape_dims):
if len(dims) == 1 or (len(dims) == 2 and dims[0] == 1): continue
for (r,view),(input_st,axis) in reduce_info.items():
if (dim:=input_st.shape[i]) > 1 and dim != max(dims):
input_st = input_st.pad(((0, 0),)*i+((0, max(dims)-dim),))
reduce_info[(r, view)] = (input_st, axis)
# create the stores
var_vals = merge_dicts([out.st.var_vals.copy() for out in outs])
assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
cache: Dict[Tuple[LazyBuffer, ShapeTracker], UOp] = {}
ast: List[UOp] = []
inputs: Dict[LazyBuffer, int] = {}
for i, out in enumerate(outs):
output_shape = ShapeTracker.reduce(*deque(reduce_info.values(), 1).pop()) if reduce_info and not AST_REWRITE else out.shape
output_st = ShapeTracker.from_shape(output_shape)
src = _recursive_uop(out, output_st, tuple(outs), var_vals, inputs, realizes, assign_targets, reduce_info, cache=cache)
output_st = ShapeTracker.from_shape(out.shape)
src = _recursive_uop(out, output_st, tuple(outs), var_vals, inputs, realizes, assign_targets, cache=cache)
if out.op is MetaOps.ASSIGN and out.arg:
assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
output_st = out.arg[0].reshape(output_shape)
output_st = out.arg[0].reshape(out.shape)
output_st, vv = output_st.simplify().unbind()
if vv: var_vals.update(vv)
ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), i)
ast.append(UOp(UOps.STORE, None, (ubuf, output_st.to_uop(), src)))
sink = UOp(UOps.SINK, None, tuple(ast))
if AST_REWRITE:
sink = graph_rewrite(sink, reduceop_fusor)
if AST_REWRITE: sink = graph_rewrite(sink, reduceop_fusor)
return [LBScheduleItem(sink, outs, list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))]

# *** DAG creation: decide which LazyBuffers should realize ***
Expand Down
2 changes: 1 addition & 1 deletion tinygrad/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def __lt__(self, x): return self.value < x
MULTIOUTPUT, PROFILE, PROFILEPATH = ContextVar("MULTIOUTPUT", 1), ContextVar("PROFILE", 0), ContextVar("PROFILEPATH", temp("tinygrad_profile.json"))
USE_TC, TC_OPT, AMX, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0), ContextVar("TRANSCENDENTAL", 1)
FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
SPLIT_REDUCEOP, AST_REWRITE = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("AST_REWRITE", 0)
SPLIT_REDUCEOP, AST_REWRITE = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("AST_REWRITE", 1)

@dataclass(frozen=True)
class Metadata:
Expand Down

0 comments on commit 3cde150

Please sign in to comment.