Skip to content

Commit

Permalink
folding for vectorized consts [run_process_replay] (tinygrad#6498)
Browse files Browse the repository at this point in the history
* folding for vectorized consts [run_process_replay]

* remove that if statement

* inf loop
  • Loading branch information
geohot authored Sep 12, 2024
1 parent a532d59 commit 327eb12
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 15 deletions.
21 changes: 10 additions & 11 deletions tinygrad/codegen/uopgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def index_collapse(idx,rng,buf,add,mul,ld,reduce):
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
(UPat(UOps.GEP, src=(UPat(UOps.VECTORIZE, name="vec"),), name="gep"),
lambda gep, vec: UOp(UOps.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
(UPat(UOps.GEP, src=(UPat.cvar("c"),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
(UPat(UOps.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
(UPat(UOps.GEP, src=(UPat(UOps.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
# tensor core with a 0 input is acc
*[(UPat(UOps.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc) for i in [2, 4, 8]],
Expand Down Expand Up @@ -279,7 +279,7 @@ def index_collapse(idx,rng,buf,add,mul,ld,reduce):
(UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
# a conditional with the same results either way is a noop, also fold const conditionals
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
(UPat.cvar("gate").where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
# ** constant folding **
(UPat(UOps.ALU, name="root", src=UPat((UOps.VCONST, UOps.CONST))),
lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
Expand Down Expand Up @@ -311,26 +311,26 @@ def index_collapse(idx,rng,buf,add,mul,ld,reduce):
# *** rules from symbolic ***
# ** lt **
# c0*x<c1 for positive int c0,c1
((UPat.cvar("c0")*UPat.var("x")).lt(UPat.cvar("c1")),
((UPat.cvar("c0", vec=False)*UPat.var("x")).lt(UPat.cvar("c1", vec=False)),
lambda x,c0,c1: x.lt(math.ceil(c1.arg/c0.arg)) if dtypes.is_int(x.dtype) and c0.arg > 0 and c1.arg > 0 else None),
# c0*x<c1 for negative int c0 and non-positive c1
((UPat.cvar("c0")*UPat.var("x")).lt(UPat.cvar("c1")),
lambda x,c0,c1: (-x).lt(-(math.floor(-c1.arg/-c0.arg))) if dtypes.is_int(x.dtype) and c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
# mul add lt
(((UPat.cvar("c0")*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1")),
(((UPat.cvar("c0", vec=False)*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1", vec=False)),
lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),
# generic lt folding
(UPat.var("x").lt(UPat.cvar("c")),
(UPat.var("x").lt(UPat.cvar("c", vec=False)),
lambda x,c: lt_folding(x, c.arg) if 0 < c.arg and dtypes.is_int(x.dtype) and not dtypes.is_unsigned(x.dtype) else None),
# ** div **
# # div folding
(UPat.var("x") // UPat.cvar("c"), lambda x,c:
(UPat.var("x") // UPat.cvar("c", vec=False), lambda x,c:
newx if 0 < c.arg and not dtypes.is_unsigned(x.dtype) and (newx:=div_folding(x,c.arg)) is not None else None),
# ** mod **
# mod folding
(UPat.var("x") % UPat.cvar("c"), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
(UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
# mul mod
((UPat.cvar("c0")*UPat.var("x")) % UPat.cvar("c1"), lambda x,c0,c1: (x%(c1//c0))*c0 if c1.arg%c0.arg == 0 else None),
((UPat.cvar("c0", vec=False)*UPat.var("x")) % UPat.cvar("c1", vec=False), lambda x,c0,c1: (x%(c1//c0))*c0 if c1.arg%c0.arg == 0 else None),
# ** combine terms **
(UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
(UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
Expand Down Expand Up @@ -364,7 +364,7 @@ def index_collapse(idx,rng,buf,add,mul,ld,reduce):
(UPat(UOps.SINK, name="root"),
lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None),
# ** move add consts to end (NOTE: this is still happening before constant folding) **
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat.cvar("c1"), UPat.var("x"))), lambda c1,x: x+c1 if x.op is not UOps.CONST else None),
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat.cvar("c1"), UPat.var("x"))), lambda c1,x: x+c1 if x.op not in (UOps.CONST, UOps.VCONST) else None),
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat.var("x"), UPat.cvar("c1"))), UPat.var("y")]),
lambda x,c1,y: (x+y)+c1),
])
Expand Down Expand Up @@ -495,8 +495,7 @@ def find_gate(x:UOp) -> Optional[UOp]:
])

no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.VCONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE), name="x"),
lambda x: UOp(x.op, dtypes.int32.vec(x.dtype.count) if x.dtype.count > 1 else dtypes.int32, x.src, x.arg) \
if x.dtype.scalar() == dtypes.pyint else None)])
lambda x: UOp(x.op, dtypes.int32.vec(x.dtype.count), x.src, x.arg) if x.dtype.scalar() == dtypes.pyint else None)])

# *** uop graph ***

Expand Down
3 changes: 2 additions & 1 deletion tinygrad/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ class DType:
count: int
def __repr__(self): return f"dtypes.{'_'*(c:=self.count!=1)}{INVERSE_DTYPES_DICT[self.name if not c else self.scalar().name]}{str(self.count)*c}"
def vec(self, sz:int):
assert sz > 1 and self.count == 1, f"can't vectorize {self} with size {sz}"
assert self.count == 1, f"can't vectorize {self} with size {sz}"
if sz == 1 or self.name == 'void': return self # void doesn't vectorize, and sz=1 is scalar
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self

Expand Down
4 changes: 2 additions & 2 deletions tinygrad/engine/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
# realized but unseen?
G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{nm(lb.realized)}"', style='filled', fillcolor="#f0c08080")

uops_colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
uops_colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.CONST: "#e0e0e0", UOps.VCONST: "#e0e0e0",
UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0", UOps.REDUCE: "#C4A484",
UOps.RANGE: "#c8a0e0", UOps.ASSIGN: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0",
UOps.RANGE: "#c8a0e0", UOps.ASSIGN: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0", UOps.SPECIAL: "#c0c0ff",
UOps.SWIZZLE: "#7ACD93", UOps.SHAPETRACKER: "#C8F9D4", UOps.REDUCE_AXIS: "#f58488"}
graph_uops_cnt = 0
def word_wrap(x, wrap=80): return x if len(x) <= wrap else (x[0:wrap] + "\n" + word_wrap(x[wrap:], wrap))
Expand Down
3 changes: 2 additions & 1 deletion tinygrad/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,8 @@ def __init__(self, op:Optional[Union[UOps, Tuple[UOps, ...]]]=None, dtype:Option
def var(name:Optional[str]=None, dtype:Optional[DType]=None): return UPat(dtype=dtype, name=name)
@staticmethod
@functools.lru_cache(None)
def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return UPat(UOps.CONST, dtype=dtype, name=name)
def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True):
return UPat((UOps.CONST, UOps.VCONST) if vec else UOps.CONST, dtype=dtype, name=name)
@staticmethod
@functools.lru_cache(None)
def const(dtype:Optional[DType], b:ConstType|Variable): return UPat(UOps.CONST, dtype=dtype, arg=b)
Expand Down

0 comments on commit 327eb12

Please sign in to comment.