Skip to content

Commit

Permalink
ISPC result cast generation: specify uniform/varying
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Feb 11, 2025
1 parent 9272929 commit 0682244
Showing 1 changed file with 47 additions and 3 deletions.
50 changes: 47 additions & 3 deletions loopy/target/ispc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
THE SOFTWARE.
"""


from typing import TYPE_CHECKING, Sequence, cast
import operator
from functools import reduce
from typing import TYPE_CHECKING, Iterable, Sequence, cast

import numpy as np

Expand All @@ -37,7 +38,7 @@

from loopy.diagnostic import LoopyError
from loopy.kernel.data import AddressSpace, ArrayArg, TemporaryVariable
from loopy.symbolic import Literal
from loopy.symbolic import CombineMapper, Literal
from loopy.target.c import CFamilyASTBuilder, CFamilyTarget
from loopy.target.c.codegen.expression import ExpressionToCExpressionMapper

Expand All @@ -50,6 +51,26 @@
from loopy.typing import Expression


class IsVaryingMapper(CombineMapper[bool, []]):
    """Decide whether an expression should be treated as ``varying``.

    The mapper yields a :class:`bool` per expression; the results of
    sub-expressions are merged with a logical "or", so one varying leaf
    makes the whole expression varying.
    """

    def combine(self, values: Iterable[bool]) -> bool:
        """Fold *values* together with ``|``, starting from ``False``."""
        # Deliberately not any(): every value is consumed, no short-circuit.
        merged = False
        for flag in values:
            merged = merged | flag
        return merged

    def map_constant(self, expr):
        """Constants are never varying."""
        return False

    def map_group_hw_index(self, expr):
        """Group hardware indices are never varying."""
        return False

    def map_local_hw_index(self, expr):
        """Return ``True`` for local axis 0; reject every other axis.

        :raises LoopyError: if ``expr.axis`` is not 0.
        """
        if expr.axis == 0:
            return True

        raise LoopyError("ISPC only supports one local axis")

    def map_variable(self, expr):
        """Plain variables are reported as not varying."""
        return False


# {{{ expression mapper

class ExprToISPCExprMapper(ExpressionToCExpressionMapper):
Expand Down Expand Up @@ -142,6 +163,29 @@ def map_subscript(self, expr, type_context):
return super().map_subscript(
expr, type_context)

def wrap_in_typecast(self, actual_type: LoopyType, needed_type: LoopyType, s):
    """Reject cast generation through this hook.

    This method unconditionally raises: per its error message, emitting a
    cast for ispc requires uniform-ness information, which the arguments
    of this hook do not carry.

    :raises NotImplementedError: always.
    """
    message = (
        "wrap_in_typecast needs uniform-ness information "
        "for ispc")
    raise NotImplementedError(message)

def rec(self, expr, type_context=None, needed_type: LoopyType | None = None):  # type: ignore[override]
    """Map *expr* and, if required, wrap the result in a qualified cast.

    The expression is first mapped by the parent class. When *needed_type*
    is given and differs from the inferred type of *expr*, the mapped
    result is wrapped in a cast of the form ``(<qualifier> <ctype>) ``,
    where the qualifier is ``varying`` or ``uniform`` as determined by
    :class:`IsVaryingMapper`.

    :arg needed_type: the type the result should have, or *None* to
        return the mapped expression unchanged.
    """
    converted = super().rec(expr, type_context)

    # No target type requested: nothing to cast.
    if needed_type is None:
        return converted

    if self.infer_type(expr) != needed_type:
        # FIXME: problematic: quadratic complexity
        is_varying = IsVaryingMapper()(expr)
        registry = self.codegen_state.ast_builder.target.get_dtype_registry()

        qualifier = "varying" if is_varying else "uniform"
        ctype = registry.dtype_to_ctype(needed_type)

        # The cast is expressed as a call to a "variable" whose name is the
        # cast prefix; the local is not named ``cast`` so that it does not
        # reuse the name of typing.cast.
        cast_prefix = var(f"({qualifier} {ctype}) ")
        return cast_prefix(converted)

    return converted

# }}}


Expand Down

0 comments on commit 0682244

Please sign in to comment.