Skip to content

Commit

Permalink
things we can delete after dtypes.void [run_process_replay] (tinygrad…
Browse files (browse the repository at this point in the history)
  • Loading branch information
Qazalin authored Sep 11, 2024
1 parent bce73c9 commit dda5c63
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 24 deletions.
6 changes: 2 additions & 4 deletions tinygrad/codegen/uopgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def fold_expanded(ex, buf):
# first, extract all the relevant offsets
offsets_rootsrc: DefaultDict[Any, dict] = defaultdict(dict)
for i,s in enumerate(new_srcs):
if (s.dtype is not None and s.dtype.count != 1) or (is_image and s.src[1].dtype.count == 2): continue
if s.dtype.count != 1 or (is_image and s.src[1].dtype.count == 2): continue
idx = s.src[1]
if idx.arg is BinaryOps.ADD and idx.src[1].op is UOps.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is UOps.CONST: root_src, arg = "CONST", idx.arg
Expand Down Expand Up @@ -408,7 +408,6 @@ def do_reduce(root:UOp):
reduce_parented, reduce_unparented = partition(root.src[1:], lambda x: x in root.src[0].sparents)
ret = root.src[0]
if len(reduce_parented):
assert root.dtype is not None
acc = UOp(UOps.DEFINE_ACC, root.dtype,
(root.const_like(identity_element(root.arg, root.dtype.scalar())),) + tuple(reduce_parented), (acc_number,))
acc_number += 1
Expand All @@ -420,7 +419,6 @@ def do_reduce(root:UOp):

def do_contract(con:UOp):
ex = con.src[0]
assert con.dtype is not None
# CONTRACT without EXPAND repeats the element VECTORIZED
if ex.op is not UOps.EXPAND: return UOp(UOps.VECTORIZE, con.dtype, con.src*con.dtype.count)
# CONTRACT may remove several axes from EXPAND
Expand Down Expand Up @@ -487,7 +485,7 @@ def find_gate(x:UOp) -> Optional[UOp]:

no_pyint = PatternMatcher([(UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE}, name="x"),
lambda x: UOp(x.op, dtypes.int32.vec(x.dtype.count) if x.dtype.count > 1 else dtypes.int32, x.src, x.arg) \
if x.dtype is not None and x.dtype.scalar() == dtypes.pyint else None)])
if x.dtype.scalar() == dtypes.pyint else None)])

# *** uop graph ***

Expand Down
10 changes: 3 additions & 7 deletions tinygrad/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,8 +525,8 @@ def type_verify(uops):
assert dtype is not None and dtype == dtype.scalar(), f"consts must be scalar, got {dtype}"
# TODO: intermediate CONST of Variable is DEFINE_VAR
assert (isinstance(arg, Variable) and u.src) or (type(arg) is type(dtypes.as_const(arg, dtype))), f"type of {arg=} does not match {dtype}"
if uop is UOps.DEFINE_ACC: assert dtype is not None and src[0].dtype == dtype, f"dtype mismatch {src[0].dtype=} != {dtype=}"
if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg
if uop is UOps.DEFINE_ACC: assert dtype != dtypes.void and src[0].dtype == dtype, f"dtype mismatch {src[0].dtype=} != {dtype=}"
if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype != dtypes.void # type is the output type, not an arg
if uop is UOps.CAST: assert dtype.count == 1 and len(src) == 1
if uop is UOps.VECTORIZE:
assert dtype.count > 1 and len(src) == dtype.count, f"dtype vectorization mismatch {dtype.count=} != {len(src)=}"
Expand Down Expand Up @@ -561,7 +561,7 @@ def type_verify(uops):
def print_uops(uops:List[UOp]):
for i,u in enumerate(uops):
formatted_parents = [uops.index(x) if x.op is not UOps.CONST else f"{x.arg}" for x in u.src]
print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str(formatted_parents):32s} {u.arg}")
print(f"{i:4d} {str(u.op):20s}: {str(u.dtype):25s} " f"{str(formatted_parents):32s} {u.arg}")

def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
flops: sint = 0
Expand All @@ -588,16 +588,12 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
elif u.op is UOps.SPECIAL:
mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
elif u.op is UOps.LOAD:
assert u.dtype is not None
mem += u.dtype.itemsize * mults
elif u.op is UOps.STORE:
assert u.src[2].dtype is not None
mem += u.src[2].dtype.itemsize * mults
elif u.op is UOps.ALU and u not in dont_count:
assert u.dtype is not None
flops += (mults * (2 if u.arg == TernaryOps.MULACC else 1)) * u.dtype.count
elif u.op is UOps.WMMA and u not in dont_count:
assert u.arg[1] is not None
flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
return flops, mem

Expand Down
5 changes: 0 additions & 5 deletions tinygrad/renderer/assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
for u in uops:
uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
if uop is UOps.IF:
assert src[0].dtype is not None
kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{uops.index(u)}", _cast(r[src[0]], dtypes.bool, src[0].dtype, u=u, pred=True)))
elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
elif uop is UOps.ENDRANGE:
Expand All @@ -185,7 +184,6 @@ def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
elif uop is UOps.ENDIF:
kk(f"IF_{r[src[0].src[0]][1:]}_{uops.index(src[0])}:")
elif uop is UOps.STORE:
assert src[0].dtype is not None and src[2].dtype is not None
assert src[0].dtype == dtypes.int64, "store isn't int64"
assert src[1].op is UOps.CONST, f"store isn't const {u}"
mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
Expand All @@ -196,10 +194,8 @@ def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype,
gate=r[src[3]] if len(src)>3 and src[3].op is not UOps.IF else None, ss=mem_type, offset=src[1].arg))
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
elif uop is UOps.ALU:
assert src[0].dtype is not None
src_dtype = src[0].dtype if args in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype
kk(self.code_for_op[args](ssa("alu", u), *[r[x] for x in src], src_dtype, self.types[src_dtype]))
elif uop is UOps.DEFINE_ACC:
Expand Down Expand Up @@ -240,7 +236,6 @@ def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
# NOTE: casting to str is fine because you can't vectorize a vectorize
elif uop is UOps.VECTORIZE: r[u] = [cast(str,r[x]) for x in src]
elif uop in {UOps.CAST, UOps.BITCAST}:
assert src[0].dtype is not None and dtype.count == 1
_cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
elif uop is UOps.DEFINE_LOCAL:
# TODO: we should sum these, and fetch 0xC000 from somewhere
Expand Down
9 changes: 3 additions & 6 deletions tinygrad/renderer/cstyle.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,11 @@ def ssa(prefix:str, u:Optional[UOp]=None):
depth -= 1
kk("}")
elif uop is UOps.STORE:
assert src[0].dtype is not None and src[2].dtype is not None
# mark DEFINE_GLOBAL buf as writable
if src[0].op is UOps.DEFINE_GLOBAL: bufs[src[0]] = (bufs[src[0]][0], (bufs[src[0]][1][0], True))
rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 and src[3].op is not UOps.IF else rendered_store)
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.RANGE:
kk(f"for (int {(expr := ssa('ridx',u))} = {r[src[0]]}; {expr} < {r[src[1]]}; {expr}++) {{")
depth += 1
Expand Down Expand Up @@ -179,7 +177,6 @@ def ssa(prefix:str, u:Optional[UOp]=None):
elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {r[src[0]]};")
elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
elif uop is UOps.GEP:
assert src[0].dtype is not None
from_ssa = src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + \
(f"[{args}]" if src[0].dtype.count > (8 if self.device in {"CUDA", "NV"} else 4) or self.device == 'CLANG' else f".{'xyzwabcd'[args]}")
Expand Down Expand Up @@ -211,7 +208,7 @@ def render_vector_prefix(self, dt:DType) -> str:
return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({(sz:=dt.itemsize)}),vector_size({sz})));"

def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
prefix, macros = [self.render_vector_prefix(dt) for dt in dedup(uop.dtype for uop in uops if uop.dtype is not None and uop.dtype.count>1)], []
prefix, macros = [self.render_vector_prefix(dt) for dt in dedup(uop.dtype for uop in uops if uop.dtype.count>1)], []
# https://github.com/corsix/amx
for name, (N, M, K), dtype_in, _, _, _, _, _ in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
macros = [
Expand Down Expand Up @@ -335,7 +332,7 @@ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
# TODO: why is dtypes.bfloat16.name == "__bf16"? would be easier not override dtypes.name
prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]

for dtype in dedup(uop.dtype for uop in uops if uop.dtype is not None and uop.dtype in (dtypes.half, dtypes.bfloat16)):
for dtype in dedup(uop.dtype for uop in uops if uop.dtype in {dtypes.half, dtypes.bfloat16}):
prefix += [f"#include <cuda_{'fp' if dtype == dtypes.half else 'bf'}16.h>"] + [self.render_vector_prefix(dtype.vec(sz)) for sz in [4, 8]]

# TODO: this has to be way better to generate for arbitrary M,N,K: use arg[1] for MNK, use arg[4] for vec sizes, encode register packing
Expand Down Expand Up @@ -420,7 +417,7 @@ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
static inline __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) == ((float)b); }
""")

for dtype in dedup(uop.dtype for uop in uops if uop.dtype is not None and uop.dtype.count > 1): prefix.append(self.render_vector_prefix(dtype))
for dtype in dedup(uop.dtype for uop in uops if uop.dtype.count > 1): prefix.append(self.render_vector_prefix(dtype))

for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
Expand Down
3 changes: 1 addition & 2 deletions tinygrad/renderer/llvmir.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def render(self, name:str, uops:List[UOp]) -> str:
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}

# create llvm function
func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values() if dtype is not None]
func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values()]
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=name)
for a in func.args:
if a.type.is_pointer: a.add_attribute("noalias")
Expand Down Expand Up @@ -105,7 +105,6 @@ def render(self, name:str, uops:List[UOp]) -> str:
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[src[0].src[1]]), loop_entry_bb, bb[-1].block)
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.RANGE:
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
bb[-2].branch(bb[-1].block)
Expand Down

0 comments on commit dda5c63

Please sign in to comment.