From 1ff88415bb74838d3e30bac553f5f14f41dc1a4f Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sun, 3 Dec 2023 19:19:59 -0800 Subject: [PATCH 1/3] Rename lang.types to _support.indexing --- .../kernel/{lang/types.py => _support/indexing.py} | 0 python/shark_turbine/kernel/_support/tracing.py | 2 +- python/shark_turbine/kernel/lang/__init__.py | 9 ++++++++- tests/kernel/{types_test.py => indexing_test.py} | 2 +- 4 files changed, 10 insertions(+), 3 deletions(-) rename python/shark_turbine/kernel/{lang/types.py => _support/indexing.py} (100%) rename tests/kernel/{types_test.py => indexing_test.py} (97%) diff --git a/python/shark_turbine/kernel/lang/types.py b/python/shark_turbine/kernel/_support/indexing.py similarity index 100% rename from python/shark_turbine/kernel/lang/types.py rename to python/shark_turbine/kernel/_support/indexing.py diff --git a/python/shark_turbine/kernel/_support/tracing.py b/python/shark_turbine/kernel/_support/tracing.py index 2f195e0f9..2715a6107 100644 --- a/python/shark_turbine/kernel/_support/tracing.py +++ b/python/shark_turbine/kernel/_support/tracing.py @@ -7,7 +7,7 @@ import torch.fx as fx -from ..lang.types import ( +from .indexing import ( KernelBuffer, ) diff --git a/python/shark_turbine/kernel/lang/__init__.py b/python/shark_turbine/kernel/lang/__init__.py index 332381984..87ead9569 100644 --- a/python/shark_turbine/kernel/lang/__init__.py +++ b/python/shark_turbine/kernel/lang/__init__.py @@ -1,2 +1,9 @@ from .prims import * -from .types import * + +# Include publics from the _support library. +from .._support.indexing import ( + Grid, + KernelBuffer, + SymbolDef, + sym, +) \ No newline at end of file diff --git a/tests/kernel/types_test.py b/tests/kernel/indexing_test.py similarity index 97% rename from tests/kernel/types_test.py rename to tests/kernel/indexing_test.py index e824c9478..13daac67c 100644 --- a/tests/kernel/types_test.py +++ b/tests/kernel/indexing_test.py @@ -2,7 +2,7 @@ import torch -from shark_turbine.kernel.lang.types import * +from shark_turbine.kernel._support.indexing import * M = sym.M N = sym.N From 4ac93a42bbb9b279ece1c8e67dabf4302db84bf5 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sun, 3 Dec 2023 22:40:57 -0800 Subject: [PATCH 2/3] Implement function emission. 
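Adds element type and usage metadata to KernelBuffer (with InputBuffer,
OutputBuffer and TemporaryBuffer subtypes), an IndexingContext for binding
symbols to static constants, and a first compiler layer (ModuleBuilder,
Signature, ThreadEmitter) that emits a `func.func` whose argument types are
derived from a traced kernel's placeholders: KernelBuffer args become
memrefs (with dims taken from the IndexingContext, or `?` when unbound) and
SymbolDef args become `index`. Graph body emission is still stubbed.

Minimal sketch of the new buffer typing surface; it mirrors the case
asserted in tests/kernel/indexing_test.py below and assumes no names beyond
those this patch exports from kernel.lang:

    import torch
    from shark_turbine.kernel.lang import InputBuffer, sym

    M = sym.M

    # Specialize a buffer type with a symbolic shape, usage and element type.
    T = InputBuffer[M].of(torch.float16)
    assert repr(T) == "InputBuffer[M].of(torch.float16)"

See tests/kernel/vector_codegen_test.py below for an end-to-end emission
example.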
--- .../shark_turbine/kernel/_support/indexing.py | 244 +++++++++++++++--- .../shark_turbine/kernel/_support/typing.py | 2 - .../shark_turbine/kernel/compiler/builder.py | 26 ++ python/shark_turbine/kernel/compiler/ir.py | 17 ++ .../kernel/compiler/vector_codegen.py | 149 +++++++++++ python/shark_turbine/kernel/gen/thread.py | 7 +- python/shark_turbine/kernel/lang/__init__.py | 5 +- tests/kernel/indexing_test.py | 4 + tests/kernel/vector_codegen_test.py | 47 ++++ 9 files changed, 466 insertions(+), 35 deletions(-) delete mode 100644 python/shark_turbine/kernel/_support/typing.py create mode 100644 python/shark_turbine/kernel/compiler/builder.py create mode 100644 python/shark_turbine/kernel/compiler/ir.py create mode 100644 python/shark_turbine/kernel/compiler/vector_codegen.py create mode 100644 tests/kernel/vector_codegen_test.py diff --git a/python/shark_turbine/kernel/_support/indexing.py b/python/shark_turbine/kernel/_support/indexing.py index 650e8a526..a60ac78d6 100644 --- a/python/shark_turbine/kernel/_support/indexing.py +++ b/python/shark_turbine/kernel/_support/indexing.py @@ -1,15 +1,71 @@ -from typing import ClassVar, Optional, Type, TypeVar, Union, cast +from typing import Any, ClassVar, Optional, Type, TypeVar, Union, cast + +from abc import ABC, abstractmethod +from enum import Enum +import threading + import torch __all__ = [ "KernelBuffer", "Grid", + "InputBuffer", + "OutputBuffer", "SymbolDef", + "TemporaryBuffer", "sym", ] -Grid = tuple[int, ...] +_tls = threading.local() + + +class NotSetType: + ... + +NotSet = NotSetType() + +SubtypeT = TypeVar("SubtypeT") + +############################################################################### +# ElementType +############################################################################### + + +class ElementType(ABC): + @staticmethod + def cast(something) -> "ElementType": + if isinstance(something, torch.dtyp): + return TorchElementType(something) + else: + raise TypeError( + f"Cannot convert {something} (of type {type(something)}) to an element type" + ) + + @abstractmethod + def ir_type_asm(self) -> str: + ... + + +class TorchElementType(ElementType): + def __init__(self, dtype: torch.dtype): + self.dtype = dtype + + def __repr__(self): + return repr(self.dtype) + + def __eq__(self, other): + return isinstance(other, TorchElementType) and self.dtype == other.dtype + + def ir_type_asm(self) -> str: + dtype = self.dtype + if dtype == torch.float32: + return "f32" + else: + raise ValueError(f"Torch dtype {dtype} cannot be mapped to MLIR type") + + +DefaultElementType = TorchElementType(torch.float32) ############################################################################### # Dimension symbols @@ -127,41 +183,114 @@ class ShapedGrid(Grid, symbolic_shape=symbolic_shape): ############################################################################### +class KernelBufferUsage(Enum): + NONE = 0 + INPUT = 1 + OUTPUT = 2 + TEMPORARY = 3 + + @staticmethod + def _type_name(v) -> str: + if v == KernelBufferUsage.NONE: + return "KernelBuffer" + elif v == KernelBufferUsage.INPUT: + return "InputBuffer" + elif v == KernelBufferUsage.OUTPUT: + return "OutputBuffer" + elif v == KernelBufferUsage.TEMPORARY: + return "TemporaryBuffer" + else: + raise AssertionError(f"uncovered KernelBufferUsage enum ({v})") + + class _KernelBufferMeta(type): """Meta-class for kernel buffers. This lets us specialize with symbolic shape information. 
""" + element_type: ElementType + usage: KernelBufferUsage + symbolic_shape: Optional[tuple[SymbolDef]] + rank: Optional[int] + def __new__( mcls, name: str, bases, dct, - *, - symbolic_shape: Optional[tuple[SymbolDef]], ): + element_type = dct.get("element_type") or DefaultElementType + dct["element_type"] = element_type + usage = dct.get("usage") or KernelBufferUsage.NONE + dct["usage"] = usage + if "usage" not in dct: + dct["usage"] = KernelBufferUsage.NONE + symbolic_shape = dct.get("symbolic_shape") + dct["symbolic_shape"] = symbolic_shape + dct["rank"] = len(symbolic_shape) if symbolic_shape is not None else None + dct["__qualname__"] = _kernel_buffer_type_repr( + element_type=element_type, usage=usage, symbolic_shape=symbolic_shape + ) new_class = type.__new__(mcls, name, bases, dct) - new_class.symbolic_shape = symbolic_shape - new_class.rank = len(symbolic_shape) if symbolic_shape is not None else None - new_class.__qualname__ = repr(new_class) return new_class - def __class_getitem__( - cls, symbolic_shape: Union[SymbolDef, tuple[SymbolDef]] - ) -> Type["KernelBuffer"]: - if not isinstance(symbolic_shape, tuple): - symbolic_shape = (symbolic_shape,) - return cast(KernelBuffer, _make_shaped_kernel_buffer(cls, symbolic_shape)) - - def __repr__(self): - if self.symbolic_shape: - return f"KernelBuffer[{', '.join(s.name for s in self.symbolic_shape)}]" - else: - return "KernelBuffer" - - -class KernelBuffer(metaclass=_KernelBufferMeta, symbolic_shape=None): + def new_subtype( + cls: Type[SubtypeT], + *, + element_type: Union[NotSetType, ElementType] = NotSet, + symbolic_shape: Union[NotSetType, Optional[tuple[SymbolDef]]] = NotSet, + usage: Union[NotSetType, KernelBufferUsage] = NotSet, + ) -> Type[SubtypeT]: + init_element_type = ( + element_type if element_type is not NotSet else cls.element_type + ) + init_symbolic_shape = ( + symbolic_shape if symbolic_shape is not NotSet else cls.symbolic_shape + ) + init_usage = usage if usage is not NotSet else cls.usage + + class Subtype(cls): + element_type = init_element_type + symbolic_shape = init_symbolic_shape + usage = init_usage + + return Subtype + + def of( + cls: Type[SubtypeT], element_type: Union[Any, ElementType, torch.dtype] + ) -> Type[SubtypeT]: + return cls.new_subtype(element_type=element_type) + + def __repr__(cls): + return _kernel_buffer_type_repr( + element_type=cls.element_type, + usage=cls.usage, + symbolic_shape=cls.symbolic_shape, + ) + + +def _is_kernel_buffer_meta_derived(t: type) -> bool: + return isinstance(t, _KernelBufferMeta) + + +def _kernel_buffer_type_repr( + *, + element_type: ElementType, + usage: KernelBufferUsage, + symbolic_shape: Optional[tuple[SymbolDef]], +) -> str: + root = KernelBufferUsage._type_name(usage) + if symbolic_shape: + stem = f"{root}[{', '.join(s.name for s in symbolic_shape)}]" + else: + stem = f"{root}" + if element_type != DefaultElementType: + stem += f".of({element_type})" + return stem + + +class KernelBuffer(metaclass=_KernelBufferMeta): """Represents a buffer in global memory. Top level kernels always operate on global memory via these @@ -174,8 +303,9 @@ class KernelBuffer(metaclass=_KernelBufferMeta, symbolic_shape=None): is used. 
""" + usage: ClassVar[KernelBufferUsage] symbolic_shape: ClassVar[Optional[tuple[SymbolDef]]] - rank: int + rank: Optional[int] def __init__(self, tensor: torch.Tensor): assert isinstance(tensor, torch.Tensor), f"Expected Tensor but got {tensor}" @@ -188,6 +318,13 @@ def __init__(self, tensor: torch.Tensor): self._tensor = tensor self.rank = tensor_rank + def __class_getitem__( + cls, symbolic_shape: Union[SymbolDef, tuple[SymbolDef]] + ) -> Type["KernelBuffer"]: + if not isinstance(symbolic_shape, tuple): + symbolic_shape = (symbolic_shape,) + return cast(cls, cls.new_subtype(symbolic_shape=symbolic_shape)) + def __repr__(self): return f"{type(self)}({self._tensor})" @@ -198,10 +335,59 @@ def __getitem__(self, key): return self._tensor.__getitem__(key) -def _make_shaped_kernel_buffer( - cls: Type[KernelBuffer], symbolic_shape: tuple[SymbolDef] -): - class ShapedKernelBuffer(KernelBuffer, symbolic_shape=symbolic_shape): - ... +class InputBuffer(KernelBuffer): + usage = KernelBufferUsage.INPUT + + +class OutputBuffer(KernelBuffer): + usage = KernelBufferUsage.OUTPUT - return ShapedKernelBuffer + +class TemporaryBuffer(KernelBuffer): + usage = KernelBufferUsage.TEMPORARY + + +############################################################################### +# IndexingContext +############################################################################### + + +class IndexingContext: + """The indexing context is responsible handling the binding of indexed + symbols to concrete values. + """ + + def __init__(self): + self.constant_bindings: dict[SymbolDef, int] = {} + + def bind_constant(self, sym: SymbolDef, value: int): + existing = self.constant_bindings.get(sym) + if existing is not None and existing != value: + raise ValueError( + f"Attempt to rebind symbol {sym} to different constant ({value} vs {existing})" + ) + self.constant_bindings[sym] = value + + def get_static_value(self, sym: SymbolDef) -> Optional[int]: + """If the symbol can be resolved to a static value, returns it.""" + return self.constant_bindings.get(sym) + + ##### Context management. 
+ @staticmethod + def current() -> "IndexingContext": + try: + return _tls.indexing_stack[-1] + except (AttributeError, IndexError): + raise AssertionError("no IndexingContext is active") + + def __enter__(self) -> "IndexingContext": + try: + stack = _tls.indexing_stack + except AttributeError: + stack = [] + _tls.indexing_stack = stack + stack.append(self) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + _tls.indexing_stack.pop() diff --git a/python/shark_turbine/kernel/_support/typing.py b/python/shark_turbine/kernel/_support/typing.py deleted file mode 100644 index b5cc7b7c1..000000000 --- a/python/shark_turbine/kernel/_support/typing.py +++ /dev/null @@ -1,2 +0,0 @@ -from typing import Callable -import inpsect diff --git a/python/shark_turbine/kernel/compiler/builder.py b/python/shark_turbine/kernel/compiler/builder.py new file mode 100644 index 000000000..e1ed0d9a2 --- /dev/null +++ b/python/shark_turbine/kernel/compiler/builder.py @@ -0,0 +1,26 @@ +from typing import Optional + +from .ir import ( + Context, + Location, + Operation, + builtin_d, +) + + +class ModuleBuilder: + def __init__( + self, + *, + context: Optional[Context] = None, + module_op: Optional[Operation] = None + ): + if module_op: + self.module_op = module_op + self.body_block = module_op.regions[0].blocks[0] + else: + if not context: + context = Context() + self.module_op = builtin_d.ModuleOp(loc=Location.unknown(context)) + self.body_block = self.module_op.body + self.context = self.module_op.context diff --git a/python/shark_turbine/kernel/compiler/ir.py b/python/shark_turbine/kernel/compiler/ir.py new file mode 100644 index 000000000..facfc22bf --- /dev/null +++ b/python/shark_turbine/kernel/compiler/ir.py @@ -0,0 +1,17 @@ +from iree.compiler.ir import ( + Context, + F32Type, + FunctionType, + IndexType, + InsertionPoint, + Location, + Operation, + MemRefType, + Type as IrType, + Value, +) + +from iree.compiler.dialects import ( + builtin as builtin_d, + func as func_d, +) diff --git a/python/shark_turbine/kernel/compiler/vector_codegen.py b/python/shark_turbine/kernel/compiler/vector_codegen.py new file mode 100644 index 000000000..0f23d5975 --- /dev/null +++ b/python/shark_turbine/kernel/compiler/vector_codegen.py @@ -0,0 +1,149 @@ +from typing import Type, Optional, Sequence, Union, cast + +from dataclasses import dataclass + +import torch.fx as fx + +from .._support.indexing import ( + Grid, + IndexingContext, + KernelBuffer, + SymbolDef, + _is_kernel_buffer_meta_derived, +) + +from .builder import ( + ModuleBuilder, +) + +from .ir import ( + FunctionType, + IndexType, + InsertionPoint, + IrType, + Location, + Value, + func_d, +) + + +ArgTypeUnion = Union[SymbolDef, Type[KernelBuffer]] + + +@dataclass +class ArgMeta: + name: Optional[str] = None + node: Optional[fx.Node] = None + + +class Signature: + """Represents a function signature. + + Signatures can carry: + - Input, output and temporary KernelBuffers + - SymbolDef + + For now, if we enounter any of these, we emit them in declaration order. + We need a better convention than this (i.e. inputs, then outputs, them symbols, them temporaries). 
+ """ + + def __init__(self): + self.args: list[tuple[ArgMeta, ArgTypeUnion]] = [] + + def add_kernel_buffer( + self, kb: Type[KernelBuffer], *, meta: Optional[ArgMeta] = None + ): + self.args.append((meta if meta is not None else ArgMeta(), kb)) + + def add_symbol(self, sym: SymbolDef, *, meta: Optional[ArgMeta] = None): + self.args.append((meta if meta is not None else ArgMeta(), sym)) + + @property + def arg_metas(self) -> Sequence[ArgMeta]: + return (meta for meta, _ in self.args) + + def as_function_type(self) -> FunctionType: + idx_c = IndexingContext.current() + + def sym_to_dim_asm(s: SymbolDef) -> str: + static_value = idx_c.get_static_value(s) + return "?" if static_value is None else str(static_value) + + def as_mlir_type(t: ArgTypeUnion) -> FunctionType: + if isinstance(t, SymbolDef): + return IndexType.get() + elif _is_kernel_buffer_meta_derived(t): + kb_t = t # type: KernelBuffer + element_type_asm = kb_t.element_type.ir_type_asm() + symbolic_shape = kb_t.symbolic_shape + if symbolic_shape is not None: + shape_asm = "x".join(sym_to_dim_asm(s) for s in kb_t.symbolic_shape) + spec_asm = f"{shape_asm}x{element_type_asm}" + else: + # Unranked. Not well supported, but for completeness. + spec_asm = element_type_asm + memref_asm = f"memref<{spec_asm}>" + return IrType.parse(memref_asm) + + inputs = [as_mlir_type(arg) for _, arg in self.args] + return FunctionType.get(inputs, []) + + def add_from_graph_placeholders(self, graph: fx.Graph): + for node in graph.nodes: + if node.op != "placeholder": + continue + t = node.type + meta = ArgMeta(name=node.target, node=node) + if _is_kernel_buffer_meta_derived(t): + self.add_kernel_buffer(t, meta=meta) + elif issubclass(t, SymbolDef): + self.add_symbol(t, meta=meta) + + def __repr__(self): + parts = [] + for meta, arg in self.args: + part = repr(arg) + if meta.name: + part = f"{meta.name}: {part}" + parts.append(part) + return f"Signature({', '.join(parts)})" + + +class ThreadEmitter: + """Emits a 'thread function' as a `func` with a signature derived from the gm.""" + + def __init__(self, mb: ModuleBuilder, grid: Grid, sig: Signature): + self.nv_map: dict[fx.Node, Value] = {} + + # TODO: Infer a location from graph. + with InsertionPoint(mb.body_block), Location.unknown(): + ftype = sig.as_function_type() + func_op = func_d.FuncOp("kernel", ftype) + self.func_op = func_op + arg_locs = [ + ( + Location.name(meta.name) + if meta.name is not None + else Location.unknown() + ) + for meta in sig.arg_metas + ] + self.entry_block = func_op.add_entry_block(arg_locs) + + # Bind all inputs in the node-value map. + for block_arg, meta in zip(self.entry_block.arguments, sig.arg_metas): + assert ( + meta.node + ), "expected all signature args to have an associated node" + self.nv_map[meta.node] = block_arg + self.ip = InsertionPoint(self.entry_block) + + def emit_node(self, node: fx.Node): + ... + + def emit_graph(self, graph: fx.Graph): + ... + + def finish(self): + with self.ip, Location.unknown(): + func_d.ReturnOp([]) diff --git a/python/shark_turbine/kernel/gen/thread.py b/python/shark_turbine/kernel/gen/thread.py index 3e1758459..ded5e765e 100644 --- a/python/shark_turbine/kernel/gen/thread.py +++ b/python/shark_turbine/kernel/gen/thread.py @@ -29,7 +29,7 @@ def thread(*symbolic_shape: SymbolDef): GridType = Grid[symbolic_shape] - def decorator(f: Optional[TCallable] = None) -> TCallable: + def decorator(f: Optional[TCallable] = None) -> "UnconfiguredThread[TCallable]": # Eagerly capture the trace and attach it to the wrapped function. 
tracer = KernelTracer() with CompiledContext(tracer) as context: @@ -49,7 +49,7 @@ def __init__( wrapped_f: TCallable, trace: CapturedTrace, ): - self._grid_type = grid_type + self.grid_type = grid_type self._name = name self._wrapped_f = wrapped_f self._trace = trace @@ -58,7 +58,7 @@ def __getitem__(self, grid: Union[int, tuple[int]]) -> TCallable: if not isinstance(grid, tuple): grid = (grid,) assert isinstance(grid, tuple) and all(isinstance(i, int) for i in grid) - grid = self._grid_type(*grid) + grid = self.grid_type(*grid) return cast( TCallable, LaunchableThread(grid, self._name, self._wrapped_f, self._trace) ) @@ -73,6 +73,7 @@ def __init__( ): super().__init__(eager_function) self.grid = grid + self.grid_type = type(grid) self._name = name self._trace = trace self._sig = inspect.signature(eager_function) diff --git a/python/shark_turbine/kernel/lang/__init__.py b/python/shark_turbine/kernel/lang/__init__.py index 87ead9569..03cd2c402 100644 --- a/python/shark_turbine/kernel/lang/__init__.py +++ b/python/shark_turbine/kernel/lang/__init__.py @@ -3,7 +3,10 @@ # Include publics from the _support library. from .._support.indexing import ( Grid, + InputBuffer, KernelBuffer, + OutputBuffer, SymbolDef, + TemporaryBuffer, sym, -) \ No newline at end of file +) diff --git a/tests/kernel/indexing_test.py b/tests/kernel/indexing_test.py index 13daac67c..b73a11a47 100644 --- a/tests/kernel/indexing_test.py +++ b/tests/kernel/indexing_test.py @@ -61,6 +61,10 @@ def testKernelBufferInstance(self): self.assertEqual(1, kb.rank) self.assertEqual((M,), kb.symbolic_shape) + def testUsageAndElementTypeInstance(self): + T = InputBuffer[M].of(torch.float16) + self.assertEqual("InputBuffer[M].of(torch.float16)", repr(T)) + if __name__ == "__main__": unittest.main() diff --git a/tests/kernel/vector_codegen_test.py b/tests/kernel/vector_codegen_test.py new file mode 100644 index 000000000..d41083279 --- /dev/null +++ b/tests/kernel/vector_codegen_test.py @@ -0,0 +1,47 @@ +import logging +import unittest + +import torch + +import shark_turbine.kernel as tk + +from shark_turbine.kernel.compiler import ( + builder, + vector_codegen, +) +from shark_turbine.kernel._support import ( + indexing, +) + +M = tk.lang.sym.M +K = tk.lang.sym.K + + +class Test(unittest.TestCase): + # This test is using the compiler "the hard way" until we have all of the + # API layering in place. + def testIotaFx(self): + @tk.gen.thread(M) + def iota_kernel(out: tk.lang.OutputBuffer[M]): + i = tk.lang.program_id(0) + out[i] = i + + gm = iota_kernel._trace.gm + print(gm.graph) + mb = builder.ModuleBuilder() + with indexing.IndexingContext() as idxc: + idxc.bind_constant(M, 17) + + sig = vector_codegen.Signature() + sig.add_from_graph_placeholders(gm.graph) + print(sig) + emitter = vector_codegen.ThreadEmitter(mb, iota_kernel.grid_type, sig) + emitter.emit_graph(gm.graph) + emitter.finish() + print(mb.module_op.get_asm()) + mb.module_op.verify() + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() From dd8a9e61bc728bce9ce45643a24bfec46afe2233 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sun, 3 Dec 2023 22:57:10 -0800 Subject: [PATCH 3/3] Add grid to kernel. 
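Signature.add_grid() appends one `index` argument per symbolic grid
dimension (named grid0, grid1, ...), and ThreadEmitter records the matching
entry-block arguments in a grid_index_map for use by later node emission.

Sketch of the updated setup from tests/kernel/vector_codegen_test.py (the
commented function signature is approximate output, assuming M is bound to
17 as in the test):

    sig = vector_codegen.Signature()
    sig.add_from_graph_placeholders(gm.graph)
    sig.add_grid(iota_kernel.grid_type)
    emitter = vector_codegen.ThreadEmitter(mb, iota_kernel.grid_type, sig)
    # Emits roughly: func.func @kernel(%arg0: memref<17xf32>, %arg1: index)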
--- .../kernel/compiler/vector_codegen.py | 16 +++++++++++++--- tests/kernel/vector_codegen_test.py | 1 + 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/python/shark_turbine/kernel/compiler/vector_codegen.py b/python/shark_turbine/kernel/compiler/vector_codegen.py index 0f23d5975..175a73978 100644 --- a/python/shark_turbine/kernel/compiler/vector_codegen.py +++ b/python/shark_turbine/kernel/compiler/vector_codegen.py @@ -34,6 +34,7 @@ class ArgMeta: name: Optional[str] = None node: Optional[fx.Node] = None + grid_index: Optional[int] = None class Signature: @@ -88,6 +89,11 @@ def as_mlir_type(t: ArgTypeUnion) -> FunctionType: inputs = [as_mlir_type(arg) for _, arg in self.args] return FunctionType.get(inputs, []) + def add_grid(self, grid: Type[Grid]): + assert grid.symbolic_shape, "code emission requires a symbolically shaped grid" + for index, s in enumerate(grid.symbolic_shape): + self.add_symbol(s, meta=ArgMeta(grid_index=index, name=f"grid{index}")) + def add_from_graph_placeholders(self, graph: fx.Graph): for node in graph.nodes: if node.op != "placeholder": @@ -114,6 +120,7 @@ class ThreadEmitter: def __init__(self, mb: ModuleBuilder, grid: Grid, sig: Signature): self.nv_map: dict[fx.Node, Value] = {} + self.grid_index_map: list[Value] = [None] * grid.rank # TODO: Infer a location from graph. with InsertionPoint(mb.body_block), Location.unknown(): @@ -133,9 +140,12 @@ def __init__(self, mb: ModuleBuilder, grid: Grid, sig: Signature): # Bind all inputs in the node-value map. for block_arg, meta in zip(self.entry_block.arguments, sig.arg_metas): assert ( - meta.node - ), "expected all signature args to have an associated node" - self.nv_map[meta.node] = block_arg + meta.node or meta.grid_index is not None + ), "expected all signature args to have an associated node or grid_index" + if meta.node: + self.nv_map[meta.node] = block_arg + elif meta.grid_index is not None: + self.grid_index_map[meta.grid_index] = block_arg self.ip = InsertionPoint(self.entry_block) def emit_node(self, node: fx.Node): diff --git a/tests/kernel/vector_codegen_test.py b/tests/kernel/vector_codegen_test.py index d41083279..866f3ca99 100644 --- a/tests/kernel/vector_codegen_test.py +++ b/tests/kernel/vector_codegen_test.py @@ -34,6 +34,7 @@ def iota_kernel(out: tk.lang.OutputBuffer[M]): sig = vector_codegen.Signature() sig.add_from_graph_placeholders(gm.graph) + sig.add_grid(iota_kernel.grid_type) print(sig) emitter = vector_codegen.ThreadEmitter(mb, iota_kernel.grid_type, sig) emitter.emit_graph(gm.graph)