add void type to uop (tinygrad#6471)
* unwrap_dtype maybe

* uopgraph stuff that hardcoded None

* test_ops passes

* dtypes.py fixups

* update test_linearizer and friends

* more ast updates

* test_beam and test_schedule too

* add void type to uop [run_process_replay]

* remove dumb casts

* start making it green

* more cast cleanups

* more cls methods to fix

* regenerate dataset

* split UOp and NOp const

* maybe that too

* fix docs

* update test_uop_symbolic

* test_verify_ast

* new sops with no diff

* meh, type_ignore is alright

* remove that assert

---------

Co-authored-by: qazal <[email protected]>
geohot and Qazalin authored Sep 11, 2024
1 parent 1b4d182 commit bdd0c06
Showing 22 changed files with 506 additions and 502 deletions.
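
The core of the change: UOps that produce no value (STORE, SINK, SHAPETRACKER, and friends) used to carry dtype=None, making UOp.dtype an Optional[DType] that every consumer had to cast or None-check. They now carry dtypes.void, so the field is always a real DType. A minimal standalone sketch of the pattern (a toy UOp dataclass for illustration, not tinygrad's actual class):

from dataclasses import dataclass
from typing import Any, Tuple

@dataclass(frozen=True)
class DType:
  name: str

void, int32 = DType("void"), DType("int")  # stand-ins for dtypes.void / dtypes.int32

@dataclass(frozen=True)
class UOp:
  op: str
  dtype: DType = void          # always a DType now, never None
  src: Tuple["UOp", ...] = ()
  arg: Any = None

ld = UOp("LOAD", int32)         # value-producing ops carry a real dtype
st = UOp("STORE", void, (ld,))  # value-less ops carry void instead of None
sink = UOp("SINK", void, (st,))
assert sink.dtype is void       # no Optional, no cast needed downstream

This is why the diffs below are mostly mechanical None -> dtypes.void swaps, plus the removal of now-redundant casts.
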
4 changes: 2 additions & 2 deletions docs/abstractions2.py
@@ -55,8 +55,8 @@
 ld_2 = UOp(UOps.LOAD, dtypes.int32, (buf_2, ShapeTracker.from_shape((1,)).to_uop()))
 alu = ld_1 + ld_2
 output_buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 0)
-st_0 = UOp(UOps.STORE, None, (output_buf, ShapeTracker.from_shape((1,)).to_uop(), alu))
-s = UOp(UOps.SINK, None, (st_0,))
+st_0 = UOp(UOps.STORE, dtypes.void, (output_buf, ShapeTracker.from_shape((1,)).to_uop(), alu))
+s = UOp(UOps.SINK, dtypes.void, (st_0,))
 
 # convert the computation to a "linearized" format (print the format)
 from tinygrad.engine.realize import get_kernel, CompiledRunner
Binary file modified extra/datasets/sops.gz
4 changes: 2 additions & 2 deletions extra/ops.py
@@ -121,9 +121,9 @@ def create_uop(lop:LazyOp) -> UOp:
 return UOp(UOps.CONST, dtype, (st_uop,), lop.arg.val)
 buf = UOp(UOps.DEFINE_GLOBAL, membuf_dtype if isinstance(membuf_dtype, ImageDType) else PtrDType(membuf_dtype), (), lop.arg.idx)
 if lop.op is BufferOps.LOAD: return UOp(UOps.LOAD, dtype, (buf, st_uop))
-return UOp(UOps.STORE, None, (buf, st_uop, create_uop(lop.src[0])))
+return UOp(UOps.STORE, dtypes.void, (buf, st_uop, create_uop(lop.src[0])))
 src = tuple(create_uop(x) for x in lop.src)
-if lop.op is MetaOps.KERNEL: return UOp(UOps.SINK, None, src)
+if lop.op is MetaOps.KERNEL: return UOp(UOps.SINK, dtypes.void, src)
 if lop.op in ReduceOps:
 alu_op = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.PROD:BinaryOps.MUL, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, lop.op)]
 return UOp(UOps.REDUCE_AXIS, src[0].dtype, src, (alu_op, lop.arg))
30 changes: 15 additions & 15 deletions test/test_linearizer.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple, Union, cast
+from typing import List, Tuple, Union
 import numpy as np
 import unittest
 from dataclasses import replace
@@ -87,8 +87,8 @@ def test_multioutput(self):
 g0, g1, g2, g3 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtype), arg=i) for i in range(4)]
 a = UOp(UOps.LOAD, dtype, (g2, st.to_uop()))
 b = UOp(UOps.LOAD, dtype, (g3, st.to_uop()))
-out0 = UOp(UOps.STORE, None, (g0, st.to_uop(), a + b))
-out1 = UOp(UOps.STORE, None, (g1, st.to_uop(), a * b))
+out0 = UOp(UOps.STORE, dtypes.void, (g0, st.to_uop(), a + b))
+out1 = UOp(UOps.STORE, dtypes.void, (g1, st.to_uop(), a * b))
 sink = UOp(UOps.SINK, src=(out0, out1))
 
 a_t = Tensor.full(st.shape, 2).contiguous().realize()
@@ -113,7 +113,7 @@ def test_multireduce(self):
 second_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((32, 1)).to_uop()))
 diff = second_x + first_reduce*ast_const(dtypes.float, -1, (32, 1))
 second_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (0,)))
-store = UOp(UOps.STORE, None, (g0, ShapeTracker.from_shape((1, 1)).to_uop(), second_reduce))
+store = UOp(UOps.STORE, dtypes.void, (g0, ShapeTracker.from_shape((1, 1)).to_uop(), second_reduce))
 sink = UOp(UOps.SINK, src=(store,))
 opts = [
 [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)], # grouping
@@ -595,10 +595,10 @@ def test_argmax_multireduce_axis0(self):
 t = Tensor.randn(10, 20).realize()
 t_max = t.max((0,)).realize()
 real_argmax = np.argmax(t.numpy(), axis=0, keepdims=False).reshape(1, 20, 1)
-ast = UOp(UOps.SINK, None, arg=None, src=(
-UOp(UOps.STORE, None, arg=None, src=(
+ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
+UOp(UOps.STORE, dtypes.void, arg=None, src=(
 UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()),
-UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa E501
+UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa E501
 UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
 UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
 ast_const(dtypes.int, st=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), val=10),
@@ -611,10 +611,10 @@
 UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
 UOp(UOps.LOAD, dtypes.float, arg=None, src=(
 UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
-UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(20, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa E501
+UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(20, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa E501
 UOp(UOps.LOAD, dtypes.float, arg=None, src=(
 UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
-UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa E501
+UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa E501
 ast_const(dtypes.bool, True, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)), # noqa E501
 UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
 UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=(
@@ -627,10 +627,10 @@ def test_argmax_multireduce_flat(self):
 t = Tensor.randn(10, 20).realize()
 t_max = t.max().realize()
 real_argmax = np.argmax(t.numpy())
-ast = UOp(UOps.SINK, None, arg=None, src=(
-UOp(UOps.STORE, None, arg=None, src=(
+ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
+UOp(UOps.STORE, dtypes.void, arg=None, src=(
 UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()),
-UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
+UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
 UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
 UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
 ast_const(dtypes.int, 200, (1, 1)),
@@ -643,10 +643,10 @@
 UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
 UOp(UOps.LOAD, dtypes.float, arg=None, src=(
 UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
-UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa: E501
+UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa: E501
 UOp(UOps.LOAD, dtypes.float, arg=None, src=(
 UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
-UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa: E501
+UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa: E501
 ast_const(dtypes.bool, True, (200, 1)),)),)),
 UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
 UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
@@ -1756,7 +1756,7 @@ def test_matvec(self):
 def helper_linearizer_ast(ast:UOp, inputs:List[Tensor], *args, **kwargs):
 assert isinstance(ast, UOp), "ast must be UOp"
 inbufs = [x.lazydata.base.buffer for x in inputs]
-outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.st_arg.size, cast(DType,out.src[2].dtype)).allocate() \
+outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.st_arg.size, out.src[2].dtype).allocate() \
 for out in ast.src]
 return _helper_linearizer_opt_ast(ast, outbufs+inbufs, *args, **kwargs)
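
The dropped cast(DType, ...) here (and the cast import removed at the top of this file) existed only to satisfy the type checker while UOp.dtype was Optional[DType]; with dtypes.void standing in for None, the attribute is always a DType and type-checks directly. A hedged sketch of the typing difference (toy Buffer signature, not tinygrad's real one):

from typing import Optional, cast

class DType: pass

class Buffer:
  def __init__(self, device:str, size:int, dtype:DType): self.dtype = dtype

def alloc_old(dtype:Optional[DType]) -> Buffer:
  return Buffer("CPU", 1, cast(DType, dtype))  # Optional forces a cast at every use site

def alloc_new(dtype:DType) -> Buffer:
  return Buffer("CPU", 1, dtype)               # void default: always a DType, no cast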
