From 06822443d1b883125fa7ff4e59981e498ca4fb4b Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner
Date: Mon, 10 Feb 2025 22:59:09 -0600
Subject: [PATCH] ISPC result cast generation: specify uniform/varying

---
Review notes (this section is not part of the commit message):
  - Line structure of the patch was restored; the indentation of context
    lines is reconstructed and should be checked against the tree before
    applying.
  - rec(): the local holding the cast expression is now called 'typecast'
    so that it no longer shadows typing.cast, which the module imports.
  - IsVaryingMapper, wrap_in_typecast() and rec() gained documentation;
    control flow uses guard clauses, behavior is unchanged.
  - The post-image blob id on the 'index' line still refers to the
    originally posted revision.

 loopy/target/ispc.py | 68 +++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 65 insertions(+), 3 deletions(-)

diff --git a/loopy/target/ispc.py b/loopy/target/ispc.py
index dca9abbac..096cb2cd6 100644
--- a/loopy/target/ispc.py
+++ b/loopy/target/ispc.py
@@ -24,8 +24,9 @@
 THE SOFTWARE.
 """
 
-
-from typing import TYPE_CHECKING, Sequence, cast
+import operator
+from functools import reduce
+from typing import TYPE_CHECKING, Iterable, Sequence, cast
 
 import numpy as np
 
@@ -37,7 +38,7 @@
 
 from loopy.diagnostic import LoopyError
 from loopy.kernel.data import AddressSpace, ArrayArg, TemporaryVariable
-from loopy.symbolic import Literal
+from loopy.symbolic import CombineMapper, Literal
 from loopy.target.c import CFamilyASTBuilder, CFamilyTarget
 from loopy.target.c.codegen.expression import ExpressionToCExpressionMapper
 
@@ -50,6 +51,36 @@
     from loopy.typing import Expression
 
 
+class IsVaryingMapper(CombineMapper[bool, []]):
+    """Determine whether an expression needs the ``varying`` qualifier.
+
+    Results of subexpressions are merged with a logical OR (see
+    :meth:`combine`). Among the leaves handled here, only the index along
+    local hardware axis 0 is varying; constants, group indices and
+    variables are reported as uniform.
+    """
+
+    def combine(self, values: Iterable[bool]) -> bool:
+        return reduce(operator.or_, values, False)
+
+    def map_constant(self, expr):
+        return False
+
+    def map_group_hw_index(self, expr):
+        return False
+
+    def map_local_hw_index(self, expr):
+        if expr.axis != 0:
+            raise LoopyError("ISPC only supports one local axis")
+
+        return True
+
+    def map_variable(self, expr):
+        # NOTE(review): all variables are reported as uniform; confirm that
+        # no varying temporaries can reach this mapper.
+        return False
+
+
 # {{{ expression mapper
 
 class ExprToISPCExprMapper(ExpressionToCExpressionMapper):
@@ -142,6 +173,37 @@ def map_subscript(self, expr, type_context):
         return super().map_subscript(
                 expr, type_context)
 
+    # The generic cast wrapper cannot be used here: as the message says, an
+    # ISPC cast needs uniform-ness information, which rec() below supplies.
+    def wrap_in_typecast(self, actual_type: LoopyType, needed_type: LoopyType, s):
+        raise NotImplementedError("wrap_in_typecast needs uniform-ness information "
+                                  "for ispc")
+
+    def rec(self, expr, type_context=None, needed_type: LoopyType | None = None):  # type: ignore[override]
+        """Map *expr* and, if *needed_type* is given and differs from the
+        inferred type of *expr*, wrap the result in a cast to *needed_type*
+        that carries an explicit ``uniform``/``varying`` qualifier.
+        """
+        result = super().rec(expr, type_context)
+
+        if needed_type is None:
+            return result
+
+        actual_type = self.infer_type(expr)
+        if actual_type != needed_type:
+            # FIXME: problematic: quadratic complexity
+            is_varying = IsVaryingMapper()(expr)
+            registry = self.codegen_state.ast_builder.target.get_dtype_registry()
+            # Not named 'cast': that would shadow typing.cast, which this
+            # module imports.
+            typecast = var("("
+                           f"{'varying' if is_varying else 'uniform'} "
+                           f"{registry.dtype_to_ctype(needed_type)}"
+                           ") ")
+            return typecast(result)
+
+        return result
+
 # }}}
 
 