diff --git a/python/shark_turbine/kernel/_support/indexing.py b/python/shark_turbine/kernel/_support/indexing.py
new file mode 100644
index 000000000..a60ac78d6
--- /dev/null
+++ b/python/shark_turbine/kernel/_support/indexing.py
@@ -0,0 +1,393 @@
+from typing import Any, ClassVar, Optional, Type, TypeVar, Union, cast
+
+from abc import ABC, abstractmethod
+from enum import Enum
+import threading
+
+import torch
+
+__all__ = [
+    "KernelBuffer",
+    "Grid",
+    "InputBuffer",
+    "OutputBuffer",
+    "SymbolDef",
+    "TemporaryBuffer",
+    "sym",
+]
+
+_tls = threading.local()
+
+
+class NotSetType:
+    ...
+
+
+NotSet = NotSetType()
+
+SubtypeT = TypeVar("SubtypeT")
+
+###############################################################################
+# ElementType
+###############################################################################
+
+
+class ElementType(ABC):
+    @staticmethod
+    def cast(something) -> "ElementType":
+        if isinstance(something, torch.dtype):
+            return TorchElementType(something)
+        else:
+            raise TypeError(
+                f"Cannot convert {something} (of type {type(something)}) to an element type"
+            )
+
+    @abstractmethod
+    def ir_type_asm(self) -> str:
+        ...
+
+
+class TorchElementType(ElementType):
+    def __init__(self, dtype: torch.dtype):
+        self.dtype = dtype
+
+    def __repr__(self):
+        return repr(self.dtype)
+
+    def __eq__(self, other):
+        return isinstance(other, TorchElementType) and self.dtype == other.dtype
+
+    def ir_type_asm(self) -> str:
+        dtype = self.dtype
+        if dtype == torch.float32:
+            return "f32"
+        else:
+            raise ValueError(f"Torch dtype {dtype} cannot be mapped to MLIR type")
+
+
+DefaultElementType = TorchElementType(torch.float32)
+
+###############################################################################
+# Dimension symbols
+###############################################################################
+
+
+class SymbolDef:
+    """Represents a named symbol representing a dimension in a shape."""
+
+    ALL_SYMBOLS: ClassVar[dict[str, "SymbolDef"]] = dict()
+    name: str
+
+    def __new__(cls, name: str):
+        existing = cls.ALL_SYMBOLS.get(name)
+        if existing is not None:
+            return existing
+        new = super().__new__(cls)
+        new.name = name
+        cls.ALL_SYMBOLS[name] = new
+        return new
+
+    def __repr__(self):
+        return f"Symbol({self.name})"
+
+    @classmethod
+    def create_expando(cls):
+        """Create an expando class that creates unique symbols based on attr access."""
+
+        class Expando:
+            def __getattr__(self, n):
+                return cls(n)
+
+        return Expando()
+
+
+sym = SymbolDef.create_expando()
+
+###############################################################################
+# Grid
+###############################################################################
+
+
+class _GridMeta(type):
+    """Meta-class for a symbolically shaped grid."""
+
+    def __new__(
+        mcls,
+        name: str,
+        bases,
+        dct,
+        *,
+        symbolic_shape: Optional[tuple[SymbolDef]],
+    ):
+        new_class = type.__new__(mcls, name, bases, dct)
+        new_class.symbolic_shape = symbolic_shape
+        new_class.rank = len(symbolic_shape) if symbolic_shape is not None else None
+        new_class.__qualname__ = repr(new_class)
+        return new_class
+
+    def __class_getitem__(
+        cls, symbolic_shape: Union[SymbolDef, tuple[SymbolDef]]
+    ) -> Type["Grid"]:
+        if not isinstance(symbolic_shape, tuple):
+            symbolic_shape = (symbolic_shape,)
+        return cast(Grid, _make_shaped_grid(cls, symbolic_shape))
+
+    def __repr__(self):
+        if self.symbolic_shape:
+            return f"Grid[{', '.join(s.name for s in self.symbolic_shape)}]"
+        else:
+            return "Grid"
+
+
+class Grid(metaclass=_GridMeta, symbolic_shape=None):
+    """Grid with bounding symbolic shape information in the type."""
+
+    symbolic_shape: ClassVar[Optional[tuple[SymbolDef]]]
+    rank: int
+
+    def __init__(self, *dims: int):
+        rank = len(dims)
+        if self.symbolic_shape is not None:
+            if rank != len(self.symbolic_shape):
+                raise ValueError(
+                    f"Cannot create {type(self)}({', '.join(str(i) for i in dims)}): mismatched symbolic rank"
+                )
+
+        self.dims = dims
+        # Shadow the type rank with the actual, which makes it concrete
+        # for the generic case.
+        self.rank = rank
+
+    def __repr__(self):
+        return f"{repr(type(self))}({', '.join(str(i) for i in self.dims)})"
+
+    def __getitem__(self, index: int) -> int:
+        return self.dims[index]
+
+    def __len__(self) -> int:
+        return len(self.dims)
+
+    def __iter__(self):
+        return iter(self.dims)
+
+
+def _make_shaped_grid(cls: Type[Grid], symbolic_shape: tuple[SymbolDef]):
+    class ShapedGrid(Grid, symbolic_shape=symbolic_shape):
+        ...
+
+    return ShapedGrid
+
+
+###############################################################################
+# KernelBuffer
+###############################################################################
+
+
+class KernelBufferUsage(Enum):
+    NONE = 0
+    INPUT = 1
+    OUTPUT = 2
+    TEMPORARY = 3
+
+    @staticmethod
+    def _type_name(v) -> str:
+        if v == KernelBufferUsage.NONE:
+            return "KernelBuffer"
+        elif v == KernelBufferUsage.INPUT:
+            return "InputBuffer"
+        elif v == KernelBufferUsage.OUTPUT:
+            return "OutputBuffer"
+        elif v == KernelBufferUsage.TEMPORARY:
+            return "TemporaryBuffer"
+        else:
+            raise AssertionError(f"uncovered KernelBufferUsage enum ({v})")
+
+
+class _KernelBufferMeta(type):
+    """Meta-class for kernel buffers.
+
+    This lets us specialize with symbolic shape information.
+    """
+
+    element_type: ElementType
+    usage: KernelBufferUsage
+    symbolic_shape: Optional[tuple[SymbolDef]]
+    rank: Optional[int]
+
+    def __new__(
+        mcls,
+        name: str,
+        bases,
+        dct,
+    ):
+        element_type = dct.get("element_type") or DefaultElementType
+        dct["element_type"] = element_type
+        usage = dct.get("usage") or KernelBufferUsage.NONE
+        dct["usage"] = usage
+        symbolic_shape = dct.get("symbolic_shape")
+        dct["symbolic_shape"] = symbolic_shape
+        dct["rank"] = len(symbolic_shape) if symbolic_shape is not None else None
+        dct["__qualname__"] = _kernel_buffer_type_repr(
+            element_type=element_type, usage=usage, symbolic_shape=symbolic_shape
+        )
+        new_class = type.__new__(mcls, name, bases, dct)
+        return new_class
+
+    def new_subtype(
+        cls: Type[SubtypeT],
+        *,
+        element_type: Union[NotSetType, ElementType] = NotSet,
+        symbolic_shape: Union[NotSetType, Optional[tuple[SymbolDef]]] = NotSet,
+        usage: Union[NotSetType, KernelBufferUsage] = NotSet,
+    ) -> Type[SubtypeT]:
+        init_element_type = (
+            element_type if element_type is not NotSet else cls.element_type
+        )
+        init_symbolic_shape = (
+            symbolic_shape if symbolic_shape is not NotSet else cls.symbolic_shape
+        )
+        init_usage = usage if usage is not NotSet else cls.usage
+
+        class Subtype(cls):
+            element_type = init_element_type
+            symbolic_shape = init_symbolic_shape
+            usage = init_usage
+
+        return Subtype
+
+    def of(
+        cls: Type[SubtypeT], element_type: Union[Any, ElementType, torch.dtype]
+    ) -> Type[SubtypeT]:
+        return cls.new_subtype(element_type=element_type)
+
+    def __repr__(cls):
+        return _kernel_buffer_type_repr(
+            element_type=cls.element_type,
+            usage=cls.usage,
+            symbolic_shape=cls.symbolic_shape,
+        )
+
+
+def _is_kernel_buffer_meta_derived(t: type) -> bool:
+    return isinstance(t, _KernelBufferMeta)
+
+
+def _kernel_buffer_type_repr(
+    *,
+    element_type: ElementType,
+    usage: KernelBufferUsage,
+    symbolic_shape: Optional[tuple[SymbolDef]],
+) -> str:
+    root = KernelBufferUsage._type_name(usage)
+    if symbolic_shape:
+        stem = f"{root}[{', '.join(s.name for s in symbolic_shape)}]"
+    else:
+        stem = f"{root}"
+    if element_type != DefaultElementType:
+        stem += f".of({element_type})"
+    return stem
+
+
+class KernelBuffer(metaclass=_KernelBufferMeta):
+    """Represents a buffer in global memory.
+
+    Top level kernels always operate on global memory via these
+    buffers, and the primary operations that can be performed on
+    them are loads/stores and DMAs to some form of compute
+    capable local buffer.
+
+    When executing eagerly, these are backed by a normal torch
+    Tensor. When compiling, an appropriate duck-typed proxy
+    is used.
+    """
+
+    usage: ClassVar[KernelBufferUsage]
+    symbolic_shape: ClassVar[Optional[tuple[SymbolDef]]]
+    rank: Optional[int]
+
+    def __init__(self, tensor: torch.Tensor):
+        assert isinstance(tensor, torch.Tensor), f"Expected Tensor but got {tensor}"
+        type_rank = type(self).rank
+        tensor_rank = len(tensor.shape)
+        if type_rank is not None and type_rank != tensor_rank:
+            raise ValueError(
+                f"Cannot create {type(self)}(tensor({tensor.shape})): mismatched symbolic rank"
+            )
+        self._tensor = tensor
+        self.rank = tensor_rank
+
+    def __class_getitem__(
+        cls, symbolic_shape: Union[SymbolDef, tuple[SymbolDef]]
+    ) -> Type["KernelBuffer"]:
+        if not isinstance(symbolic_shape, tuple):
+            symbolic_shape = (symbolic_shape,)
+        return cast(cls, cls.new_subtype(symbolic_shape=symbolic_shape))
+
+    def __repr__(self):
+        return f"{type(self)}({self._tensor})"
+
+    def __setitem__(self, key, item):
+        self._tensor.__setitem__(key, item)
+
+    def __getitem__(self, key):
+        return self._tensor.__getitem__(key)
+
+
+class InputBuffer(KernelBuffer):
+    usage = KernelBufferUsage.INPUT
+
+
+class OutputBuffer(KernelBuffer):
+    usage = KernelBufferUsage.OUTPUT
+
+
+class TemporaryBuffer(KernelBuffer):
+    usage = KernelBufferUsage.TEMPORARY
+
+
+###############################################################################
+# IndexingContext
+###############################################################################
+
+
+class IndexingContext:
+    """The indexing context is responsible for handling the binding of indexed
+    symbols to concrete values.
+    """
+
+    def __init__(self):
+        self.constant_bindings: dict[SymbolDef, int] = {}
+
+    def bind_constant(self, sym: SymbolDef, value: int):
+        existing = self.constant_bindings.get(sym)
+        if existing is not None and existing != value:
+            raise ValueError(
+                f"Attempt to rebind symbol {sym} to different constant ({value} vs {existing})"
+            )
+        self.constant_bindings[sym] = value
+
+    def get_static_value(self, sym: SymbolDef) -> Optional[int]:
+        """If the symbol can be resolved to a static value, returns it."""
+        return self.constant_bindings.get(sym)
+
+    ##### Context management.
+    @staticmethod
+    def current() -> "IndexingContext":
+        try:
+            return _tls.indexing_stack[-1]
+        except (AttributeError, IndexError):
+            raise AssertionError("no IndexingContext is active")
+
+    def __enter__(self) -> "IndexingContext":
+        try:
+            stack = _tls.indexing_stack
+        except AttributeError:
+            stack = []
+            _tls.indexing_stack = stack
+        stack.append(self)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        _tls.indexing_stack.pop()
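
Note (illustration only, not part of the patch): a minimal sketch of how the indexing primitives above compose, using only names defined in indexing.py; the symbol names, sizes, and bindings below are arbitrary.

    import torch
    from shark_turbine.kernel._support.indexing import (
        Grid,
        IndexingContext,
        InputBuffer,
        sym,
    )

    M = sym.M  # attribute access on the expando interns a SymbolDef named "M"
    N = sym.N

    GridMN = Grid[M, N]      # Grid subclass carrying the symbolic shape (M, N)
    print(GridMN(8, 8))      # Grid[M, N](8, 8); rank must match the symbolic rank

    Af16 = InputBuffer[M, N].of(torch.float16)
    print(Af16)              # InputBuffer[M, N].of(torch.float16)

    with IndexingContext() as idxc:
        idxc.bind_constant(M, 128)
        assert IndexingContext.current().get_static_value(M) == 128
        assert IndexingContext.current().get_static_value(N) is None  # unbound stays dynamic
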
diff --git a/python/shark_turbine/kernel/_support/tracing.py b/python/shark_turbine/kernel/_support/tracing.py
index 2f195e0f9..2715a6107 100644
--- a/python/shark_turbine/kernel/_support/tracing.py
+++ b/python/shark_turbine/kernel/_support/tracing.py
@@ -7,7 +7,7 @@
 
 import torch.fx as fx
 
-from ..lang.types import (
+from .indexing import (
     KernelBuffer,
 )
 
diff --git a/python/shark_turbine/kernel/_support/typing.py b/python/shark_turbine/kernel/_support/typing.py
deleted file mode 100644
index b5cc7b7c1..000000000
--- a/python/shark_turbine/kernel/_support/typing.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from typing import Callable
-import inpsect
diff --git a/python/shark_turbine/kernel/compiler/builder.py b/python/shark_turbine/kernel/compiler/builder.py
new file mode 100644
index 000000000..e1ed0d9a2
--- /dev/null
+++ b/python/shark_turbine/kernel/compiler/builder.py
@@ -0,0 +1,26 @@
+from typing import Optional
+
+from .ir import (
+    Context,
+    Location,
+    Operation,
+    builtin_d,
+)
+
+
+class ModuleBuilder:
+    def __init__(
+        self,
+        *,
+        context: Optional[Context] = None,
+        module_op: Optional[Operation] = None
+    ):
+        if module_op:
+            self.module_op = module_op
+            self.body_block = module_op.regions[0].blocks[0]
+        else:
+            if not context:
+                context = Context()
+            self.module_op = builtin_d.ModuleOp(loc=Location.unknown(context))
+            self.body_block = self.module_op.body
+        self.context = self.module_op.context
diff --git a/python/shark_turbine/kernel/compiler/ir.py b/python/shark_turbine/kernel/compiler/ir.py
new file mode 100644
index 000000000..facfc22bf
--- /dev/null
+++ b/python/shark_turbine/kernel/compiler/ir.py
@@ -0,0 +1,17 @@
+from iree.compiler.ir import (
+    Context,
+    F32Type,
+    FunctionType,
+    IndexType,
+    InsertionPoint,
+    Location,
+    Operation,
+    MemRefType,
+    Type as IrType,
+    Value,
+)
+
+from iree.compiler.dialects import (
+    builtin as builtin_d,
+    func as func_d,
+)
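
Note (illustration only, not part of the patch): constructed with no arguments, the ModuleBuilder above creates its own MLIR Context and an empty builtin module; its body_block is the insertion target that the vector codegen below emits into. A minimal sketch:

    from shark_turbine.kernel.compiler import builder

    mb = builder.ModuleBuilder()    # fresh Context plus an empty module op
    print(mb.module_op.get_asm())   # prints an empty module, roughly: module { }
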
diff --git a/python/shark_turbine/kernel/compiler/vector_codegen.py b/python/shark_turbine/kernel/compiler/vector_codegen.py
new file mode 100644
index 000000000..175a73978
--- /dev/null
+++ b/python/shark_turbine/kernel/compiler/vector_codegen.py
@@ -0,0 +1,159 @@
+from typing import Type, Optional, Sequence, Union, cast
+
+from dataclasses import dataclass
+
+import torch.fx as fx
+
+from .._support.indexing import (
+    Grid,
+    IndexingContext,
+    KernelBuffer,
+    SymbolDef,
+    _is_kernel_buffer_meta_derived,
+)
+
+from .builder import (
+    ModuleBuilder,
+)
+
+from .ir import (
+    FunctionType,
+    IndexType,
+    InsertionPoint,
+    IrType,
+    Location,
+    Value,
+    func_d,
+)
+
+
+ArgTypeUnion = Union[SymbolDef, Type[KernelBuffer]]
+
+
+@dataclass
+class ArgMeta:
+    name: Optional[str] = None
+    node: Optional[fx.Node] = None
+    grid_index: Optional[int] = None
+
+
+class Signature:
+    """Represents a function signature.
+
+    Signatures can carry:
+      - Input, output and temporary KernelBuffers
+      - SymbolDef
+
+    For now, if we encounter any of these, we emit them in declaration order.
+    We need a better convention than this (i.e. inputs, then outputs, then
+    symbols, then temporaries).
+    """
+
+    def __init__(self):
+        self.args: list[tuple[ArgMeta, ArgTypeUnion]] = []
+
+    def add_kernel_buffer(
+        self, kb: Type[KernelBuffer], *, meta: Optional[ArgMeta] = None
+    ):
+        self.args.append((meta if meta is not None else ArgMeta(), kb))
+
+    def add_symbol(self, sym: SymbolDef, *, meta: Optional[ArgMeta] = None):
+        self.args.append((meta if meta is not None else ArgMeta(), sym))
+
+    @property
+    def arg_metas(self) -> Sequence[ArgMeta]:
+        return (meta for meta, _ in self.args)
+
+    def as_function_type(self) -> FunctionType:
+        idx_c = IndexingContext.current()
+
+        def sym_to_dim_asm(s: SymbolDef) -> str:
+            static_value = idx_c.get_static_value(s)
+            return "?" if static_value is None else str(static_value)
+
+        def as_mlir_type(t: ArgTypeUnion) -> IrType:
+            if isinstance(t, SymbolDef):
+                return IndexType.get()
+            elif _is_kernel_buffer_meta_derived(t):
+                kb_t = t  # type: KernelBuffer
+                element_type_asm = kb_t.element_type.ir_type_asm()
+                symbolic_shape = kb_t.symbolic_shape
+                if symbolic_shape is not None:
+                    shape_asm = "x".join(sym_to_dim_asm(s) for s in kb_t.symbolic_shape)
+                    spec_asm = f"{shape_asm}x{element_type_asm}"
+                else:
+                    # Unranked. Not well supported, but for completeness.
+                    spec_asm = element_type_asm
+                memref_asm = f"memref<{spec_asm}>"
+                return IrType.parse(memref_asm)
+
+        inputs = [as_mlir_type(arg) for _, arg in self.args]
+        return FunctionType.get(inputs, [])
+
+    def add_grid(self, grid: Type[Grid]):
+        assert grid.symbolic_shape, "code emission requires a symbolically shaped grid"
+        for index, s in enumerate(grid.symbolic_shape):
+            self.add_symbol(s, meta=ArgMeta(grid_index=index, name=f"grid{index}"))
+
+    def add_from_graph_placeholders(self, graph: fx.Graph):
+        for node in graph.nodes:
+            if node.op != "placeholder":
+                continue
+            t = node.type
+            meta = ArgMeta(name=node.target, node=node)
+            if _is_kernel_buffer_meta_derived(t):
+                self.add_kernel_buffer(t, meta=meta)
+            elif issubclass(t, SymbolDef):
+                self.add_symbol(t, meta=meta)
+
+    def __repr__(self):
+        parts = []
+        for meta, arg in self.args:
+            part = repr(arg)
+            if meta.name:
+                part = f"{meta.name}: {part}"
+            parts.append(part)
+        return f"Signature({', '.join(parts)})"
+
+
+class ThreadEmitter:
+    """Emits a 'thread function' as a `func` with a signature derived from the gm."""
+
+    def __init__(self, mb: ModuleBuilder, grid: Grid, sig: Signature):
+        self.nv_map: dict[fx.Node, Value] = {}
+        self.grid_index_map: list[Value] = [None] * grid.rank
+
+        # TODO: Infer a location from graph.
+        with InsertionPoint(mb.body_block), Location.unknown():
+            ftype = sig.as_function_type()
+            func_op = func_d.FuncOp("kernel", ftype)
+            self.func_op = func_op
+            arg_locs = [
+                (
+                    Location.name(meta.name)
+                    if meta.name is not None
+                    else Location.unknown()
+                )
+                for meta in sig.arg_metas
+            ]
+            self.entry_block = func_op.add_entry_block(arg_locs)
+
+            # Bind all inputs in the node-value map.
+            for block_arg, meta in zip(self.entry_block.arguments, sig.arg_metas):
+                assert (
+                    meta.node or meta.grid_index is not None
+                ), "expected all signature args to have an associated node or grid_index"
+                if meta.node:
+                    self.nv_map[meta.node] = block_arg
+                elif meta.grid_index is not None:
+                    self.grid_index_map[meta.grid_index] = block_arg
+        self.ip = InsertionPoint(self.entry_block)
+
+    def emit_node(self, node: fx.Node):
+        ...
+
+    def emit_graph(self, graph: fx.Graph):
+        ...
+
+    def finish(self):
+        with self.ip, Location.unknown():
+            func_d.ReturnOp([])
diff --git a/python/shark_turbine/kernel/gen/thread.py b/python/shark_turbine/kernel/gen/thread.py
index 3e1758459..ded5e765e 100644
--- a/python/shark_turbine/kernel/gen/thread.py
+++ b/python/shark_turbine/kernel/gen/thread.py
@@ -29,7 +29,7 @@ def thread(*symbolic_shape: SymbolDef):
     GridType = Grid[symbolic_shape]
 
-    def decorator(f: Optional[TCallable] = None) -> TCallable:
+    def decorator(f: Optional[TCallable] = None) -> "UnconfiguredThread[TCallable]":
         # Eagerly capture the trace and attach it to the wrapped function.
         tracer = KernelTracer()
         with CompiledContext(tracer) as context:
@@ -49,7 +49,7 @@ def __init__(
         wrapped_f: TCallable,
         trace: CapturedTrace,
     ):
-        self._grid_type = grid_type
+        self.grid_type = grid_type
         self._name = name
         self._wrapped_f = wrapped_f
         self._trace = trace
@@ -58,7 +58,7 @@ def __getitem__(self, grid: Union[int, tuple[int]]) -> TCallable:
         if not isinstance(grid, tuple):
             grid = (grid,)
         assert isinstance(grid, tuple) and all(isinstance(i, int) for i in grid)
-        grid = self._grid_type(*grid)
+        grid = self.grid_type(*grid)
         return cast(
             TCallable, LaunchableThread(grid, self._name, self._wrapped_f, self._trace)
         )
@@ -73,6 +73,7 @@ def __init__(
     ):
         super().__init__(eager_function)
         self.grid = grid
+        self.grid_type = type(grid)
         self._name = name
         self._trace = trace
         self._sig = inspect.signature(eager_function)
diff --git a/python/shark_turbine/kernel/lang/__init__.py b/python/shark_turbine/kernel/lang/__init__.py
index 332381984..03cd2c402 100644
--- a/python/shark_turbine/kernel/lang/__init__.py
+++ b/python/shark_turbine/kernel/lang/__init__.py
@@ -1,2 +1,12 @@
 from .prims import *
-from .types import *
+
+# Include publics from the _support library.
+from .._support.indexing import (
+    Grid,
+    InputBuffer,
+    KernelBuffer,
+    OutputBuffer,
+    SymbolDef,
+    TemporaryBuffer,
+    sym,
+)
diff --git a/python/shark_turbine/kernel/lang/types.py b/python/shark_turbine/kernel/lang/types.py
deleted file mode 100644
index 650e8a526..000000000
--- a/python/shark_turbine/kernel/lang/types.py
+++ /dev/null
@@ -1,207 +0,0 @@
-from typing import ClassVar, Optional, Type, TypeVar, Union, cast
-import torch
-
-__all__ = [
-    "KernelBuffer",
-    "Grid",
-    "SymbolDef",
-    "sym",
-]
-
-Grid = tuple[int, ...]
- - -############################################################################### -# Dimension symbols -############################################################################### - - -class SymbolDef: - """Represents a named symbol representing a dimension in a shape.""" - - ALL_SYMBOLS: ClassVar[dict[str, "SymbolDef"]] = dict() - name: str - - def __new__(cls, name: str): - existing = cls.ALL_SYMBOLS.get(name) - if existing is not None: - return existing - new = super().__new__(cls) - new.name = name - cls.ALL_SYMBOLS[name] = new - return new - - def __repr__(self): - return f"Symbol({self.name})" - - @classmethod - def create_expando(cls): - """Create an expando class that creates unique symbols based on attr access.""" - - class Expando: - def __getattr__(self, n): - return cls(n) - - return Expando() - - -sym = SymbolDef.create_expando() - -############################################################################### -# Grid -############################################################################### - - -class _GridMeta(type): - """Meta-class for a symbolically shaped grid.""" - - def __new__( - mcls, - name: str, - bases, - dct, - *, - symbolic_shape: Optional[tuple[SymbolDef]], - ): - new_class = type.__new__(mcls, name, bases, dct) - new_class.symbolic_shape = symbolic_shape - new_class.rank = len(symbolic_shape) if symbolic_shape is not None else None - new_class.__qualname__ = repr(new_class) - return new_class - - def __class_getitem__( - cls, symbolic_shape: Union[SymbolDef, tuple[SymbolDef]] - ) -> Type["Grid"]: - if not isinstance(symbolic_shape, tuple): - symbolic_shape = (symbolic_shape,) - return cast(Grid, _make_shaped_grid(cls, symbolic_shape)) - - def __repr__(self): - if self.symbolic_shape: - return f"Grid[{', '.join(s.name for s in self.symbolic_shape)}]" - else: - return "Grid" - - -class Grid(metaclass=_GridMeta, symbolic_shape=None): - """Grid with bounding symbolic shape information in the type.""" - - symbolic_shape: ClassVar[Optional[tuple[SymbolDef]]] - rank: int - - def __init__(self, *dims: int): - rank = len(dims) - if self.symbolic_shape is not None: - if rank != len(self.symbolic_shape): - raise ValueError( - f"Cannot create {type(self)}({', '.join(str(i) for i in dims)}): mismatched symbolic rank" - ) - - self.dims = dims - # Shadow the type rank with the actual, which makes it concrete - # for the generic case. - self.rank = rank - - def __repr__(self): - return f"{repr(type(self))}({', '.join(str(i) for i in self.dims)})" - - def __getitem__(self, index: int) -> int: - return self.dims[index] - - def __len__(self) -> int: - return len(self.dims) - - def __iter__(self): - return iter(self.dims) - - -def _make_shaped_grid(cls: Type[Grid], symbolic_shape: tuple[SymbolDef]): - class ShapedGrid(Grid, symbolic_shape=symbolic_shape): - ... - - return ShapedGrid - - -############################################################################### -# KernelBuffer -############################################################################### - - -class _KernelBufferMeta(type): - """Meta-class for kernel buffers. - - This lets us specialize with symbolic shape information. 
- """ - - def __new__( - mcls, - name: str, - bases, - dct, - *, - symbolic_shape: Optional[tuple[SymbolDef]], - ): - new_class = type.__new__(mcls, name, bases, dct) - new_class.symbolic_shape = symbolic_shape - new_class.rank = len(symbolic_shape) if symbolic_shape is not None else None - new_class.__qualname__ = repr(new_class) - return new_class - - def __class_getitem__( - cls, symbolic_shape: Union[SymbolDef, tuple[SymbolDef]] - ) -> Type["KernelBuffer"]: - if not isinstance(symbolic_shape, tuple): - symbolic_shape = (symbolic_shape,) - return cast(KernelBuffer, _make_shaped_kernel_buffer(cls, symbolic_shape)) - - def __repr__(self): - if self.symbolic_shape: - return f"KernelBuffer[{', '.join(s.name for s in self.symbolic_shape)}]" - else: - return "KernelBuffer" - - -class KernelBuffer(metaclass=_KernelBufferMeta, symbolic_shape=None): - """Represents a buffer in global memory. - - Top level kernels always operate on global memory via these - buffers, and the primary operations that can be performed on - them are loads/stores and DMAs to some form of compute - capable local buffer. - - When executing eagerly, these are backed by a normal torch - Tensor. When compiling, an appropriate duck-typed proxy - is used. - """ - - symbolic_shape: ClassVar[Optional[tuple[SymbolDef]]] - rank: int - - def __init__(self, tensor: torch.Tensor): - assert isinstance(tensor, torch.Tensor), f"Expected Tensor but got {tensor}" - type_rank = type(self).rank - tensor_rank = len(tensor.shape) - if type_rank is not None and type_rank != tensor_rank: - raise ValueError( - f"Cannot create {type(self)}(tensor({tensor.shape})): mismatched symbolic rank" - ) - self._tensor = tensor - self.rank = tensor_rank - - def __repr__(self): - return f"{type(self)}({self._tensor})" - - def __setitem__(self, key, item): - self._tensor.__setitem__(key, item) - - def __getitem__(self, key): - return self._tensor.__getitem__(key) - - -def _make_shaped_kernel_buffer( - cls: Type[KernelBuffer], symbolic_shape: tuple[SymbolDef] -): - class ShapedKernelBuffer(KernelBuffer, symbolic_shape=symbolic_shape): - ... - - return ShapedKernelBuffer diff --git a/tests/kernel/types_test.py b/tests/kernel/indexing_test.py similarity index 89% rename from tests/kernel/types_test.py rename to tests/kernel/indexing_test.py index e824c9478..b73a11a47 100644 --- a/tests/kernel/types_test.py +++ b/tests/kernel/indexing_test.py @@ -2,7 +2,7 @@ import torch -from shark_turbine.kernel.lang.types import * +from shark_turbine.kernel._support.indexing import * M = sym.M N = sym.N @@ -61,6 +61,10 @@ def testKernelBufferInstance(self): self.assertEqual(1, kb.rank) self.assertEqual((M,), kb.symbolic_shape) + def testUsageAndElementTypeInstance(self): + T = InputBuffer[M].of(torch.float16) + self.assertEqual("InputBuffer[M].of(torch.float16)", repr(T)) + if __name__ == "__main__": unittest.main() diff --git a/tests/kernel/vector_codegen_test.py b/tests/kernel/vector_codegen_test.py new file mode 100644 index 000000000..866f3ca99 --- /dev/null +++ b/tests/kernel/vector_codegen_test.py @@ -0,0 +1,48 @@ +import logging +import unittest + +import torch + +import shark_turbine.kernel as tk + +from shark_turbine.kernel.compiler import ( + builder, + vector_codegen, +) +from shark_turbine.kernel._support import ( + indexing, +) + +M = tk.lang.sym.M +K = tk.lang.sym.K + + +class Test(unittest.TestCase): + # This test is using the compiler "the hard way" until we have all of the + # API layering in place. 
+    def testIotaFx(self):
+        @tk.gen.thread(M)
+        def iota_kernel(out: tk.lang.OutputBuffer[M]):
+            i = tk.lang.program_id(0)
+            out[i] = i
+
+        gm = iota_kernel._trace.gm
+        print(gm.graph)
+        mb = builder.ModuleBuilder()
+        with indexing.IndexingContext() as idxc:
+            idxc.bind_constant(M, 17)
+
+            sig = vector_codegen.Signature()
+            sig.add_from_graph_placeholders(gm.graph)
+            sig.add_grid(iota_kernel.grid_type)
+            print(sig)
+            emitter = vector_codegen.ThreadEmitter(mb, iota_kernel.grid_type, sig)
+            emitter.emit_graph(gm.graph)
+            emitter.finish()
+            print(mb.module_op.get_asm())
+            mb.module_op.verify()
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.DEBUG)
+    unittest.main()
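
Note (illustration only, not part of the patch): eager-mode KernelBuffer behavior defined in indexing.py above — instances wrap a torch.Tensor, validate its rank against the symbolic shape, and forward indexing to the tensor. Names and sizes below are arbitrary.

    import torch
    from shark_turbine.kernel._support.indexing import InputBuffer, sym

    M, K = sym.M, sym.K

    buf = InputBuffer[M, K](torch.zeros(4, 4))  # rank 2 tensor matches the (M, K) type
    buf[1, 2] = 3.0                             # __setitem__ forwards to the backing tensor
    assert float(buf[1, 2]) == 3.0              # __getitem__ likewise

    try:
        InputBuffer[M, K](torch.zeros(4))       # rank 1 tensor against a rank 2 type
    except ValueError as e:
        print(e)                                # reports a mismatched symbolic rank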