Skip to content

Commit

Permalink
fix: pass immutable arg_id_to_dtype to InKernelCallables
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Jan 30, 2025
1 parent a3c7eea commit 1bb84f8
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 54 deletions.
8 changes: 5 additions & 3 deletions examples/python/call-external.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from constantdict import constantdict

import loopy as lp
from loopy.diagnostic import LoopyError
Expand Down Expand Up @@ -30,9 +31,10 @@ def with_types(self, arg_id_to_dtype, callables_table):
"types")

return (self.copy(name_in_target=name_in_target,
arg_id_to_dtype={0: vec_dtype,
1: vec_dtype,
-1: vec_dtype}),
arg_id_to_dtype=constantdict({
0: vec_dtype,
1: vec_dtype,
-1: vec_dtype})),
callables_table)

def with_descrs(self, arg_id_to_descr, callables_table):
Expand Down
21 changes: 11 additions & 10 deletions loopy/kernel/function_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,9 +565,9 @@ def with_types(self, arg_id_to_dtype, callables_table):
"the function %s." % (self.name))

def with_descrs(self, arg_id_to_descr, clbl_inf_ctx):

arg_id_to_descr[-1] = ValueArgDescriptor()
return (self.copy(arg_id_to_descr=arg_id_to_descr),
new_arg_id_to_descr = dict(arg_id_to_descr)
new_arg_id_to_descr[-1] = ValueArgDescriptor()
return (self.copy(arg_id_to_descr=constantdict(new_arg_id_to_descr)),
clbl_inf_ctx)

def get_hw_axes_sizes(self, arg_id_to_arg, space, callables_table):
Expand Down Expand Up @@ -782,14 +782,15 @@ def with_descrs(self, arg_id_to_descr, clbl_inf_ctx):
# arg_id_to_descr expressions provided are from the caller's namespace,
# need to register

new_arg_id_to_descr = dict(arg_id_to_descr)
kw_to_pos, pos_to_kw = get_kw_pos_association(self.subkernel)

kw_to_callee_idx = {arg.name: i
for i, arg in enumerate(self.subkernel.args)}

new_args = self.subkernel.args[:]

for arg_id, descr in arg_id_to_descr.items():
for arg_id, descr in new_arg_id_to_descr.items():
if isinstance(arg_id, int):
arg_id = pos_to_kw[arg_id]

Expand Down Expand Up @@ -837,15 +838,15 @@ def with_descrs(self, arg_id_to_descr, clbl_inf_ctx):
for arg in subkernel.args:
kw = arg.name
if isinstance(arg, ArrayBase):
arg_id_to_descr[kw] = (
new_arg_id_to_descr[kw] = (
ArrayArgDescriptor(shape=arg.shape,
dim_tags=arg.dim_tags,
address_space=arg.address_space))
else:
assert isinstance(arg, ValueArg)
arg_id_to_descr[kw] = ValueArgDescriptor()
new_arg_id_to_descr[kw] = ValueArgDescriptor()

arg_id_to_descr[kw_to_pos[kw]] = arg_id_to_descr[kw]
new_arg_id_to_descr[kw_to_pos[kw]] = new_arg_id_to_descr[kw]

# }}}

Expand Down Expand Up @@ -879,8 +880,8 @@ def with_added_arg(self, arg_dtype, arg_descr):
arg_id_to_descr[kw_to_pos[var_name]] = arg_descr

return (self.copy(subkernel=subknl,
arg_id_to_dtype=arg_id_to_dtype,
arg_id_to_descr=arg_id_to_descr),
arg_id_to_dtype=constantdict(arg_id_to_dtype),
arg_id_to_descr=constantdict(arg_id_to_descr)),
var_name)

else:
Expand All @@ -902,7 +903,7 @@ def with_packing_for_args(self):
address_space=AddressSpace.GLOBAL)

return self.copy(subkernel=self.subkernel,
arg_id_to_descr=arg_id_to_descr)
arg_id_to_descr=constantdict(arg_id_to_descr))

def get_used_hw_axes(self, callables_table):
gsize, lsize = self.subkernel.get_grid_size_upper_bounds(callables_table,
Expand Down
9 changes: 5 additions & 4 deletions loopy/library/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from typing import TYPE_CHECKING

import numpy as np
from constantdict import constantdict

from loopy.diagnostic import LoopyError
from loopy.kernel.function_interface import ScalarCallable
Expand All @@ -38,12 +39,12 @@

class MakeTupleCallable(ScalarCallable):
def with_types(self, arg_id_to_dtype, callables_table):
new_arg_id_to_dtype = arg_id_to_dtype.copy()
new_arg_id_to_dtype = dict(arg_id_to_dtype)
for i in range(len(arg_id_to_dtype)):
if i in arg_id_to_dtype and arg_id_to_dtype[i] is not None:
new_arg_id_to_dtype[-i-1] = new_arg_id_to_dtype[i]

return (self.copy(arg_id_to_dtype=new_arg_id_to_dtype,
return (self.copy(arg_id_to_dtype=constantdict(new_arg_id_to_dtype),
name_in_target="loopy_make_tuple"), callables_table)

def with_descrs(self, arg_id_to_descr, callables_table):
Expand All @@ -52,7 +53,7 @@ def with_descrs(self, arg_id_to_descr, callables_table):
(-id-1, ValueArgDescriptor()) for id in arg_id_to_descr.keys()}

return (
self.copy(arg_id_to_descr=new_arg_id_to_descr),
self.copy(arg_id_to_descr=constantdict(new_arg_id_to_descr)),
callables_table)


Expand All @@ -63,7 +64,7 @@ def with_types(self, arg_id_to_dtype, callables_table):
if dtype is not None}
new_arg_id_to_dtype[-1] = NumpyType(np.int32)

return (self.copy(arg_id_to_dtype=new_arg_id_to_dtype),
return (self.copy(arg_id_to_dtype=constantdict(new_arg_id_to_dtype)),
callables_table)

def emit_call(self, expression_to_code_mapper, expression, target):
Expand Down
7 changes: 4 additions & 3 deletions loopy/library/random123.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from typing import TYPE_CHECKING

import numpy as np
from constantdict import constantdict
from mako.template import Template

from pymbolic.typing import not_none
Expand Down Expand Up @@ -221,7 +222,7 @@ def with_types(self, arg_id_to_dtype, callables_table):
new_arg_id_to_dtype = {-1: ctr_dtype, -2: ctr_dtype, 0: ctr_dtype, 1:
key_dtype}
return (
self.copy(arg_id_to_dtype=new_arg_id_to_dtype,
self.copy(arg_id_to_dtype=constantdict(new_arg_id_to_dtype),
name_in_target=fn+"_gen"),
callables_table)

Expand All @@ -230,15 +231,15 @@ def with_types(self, arg_id_to_dtype, callables_table):
rng_variant.width),
-2: ctr_dtype, 0: ctr_dtype, 1:
key_dtype}
return self.copy(arg_id_to_dtype=new_arg_id_to_dtype,
return self.copy(arg_id_to_dtype=constantdict(new_arg_id_to_dtype),
name_in_target=name), callables_table

elif name == fn + "_f64":
new_arg_id_to_dtype = {-1: target.vector_dtype(NumpyType(np.float64),
rng_variant.width),
-2: ctr_dtype, 0: ctr_dtype, 1:
key_dtype}
return self.copy(arg_id_to_dtype=new_arg_id_to_dtype,
return self.copy(arg_id_to_dtype=constantdict(new_arg_id_to_dtype),
name_in_target=name), callables_table

return (self.copy(arg_id_to_dtype=arg_id_to_dtype),
Expand Down
9 changes: 5 additions & 4 deletions loopy/library/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from typing import TYPE_CHECKING

import numpy as np
from constantdict import constantdict

from pymbolic import var
from pymbolic.primitives import expr_dataclass
Expand Down Expand Up @@ -580,21 +581,21 @@ def with_types(self, arg_id_to_dtype, callables_table):
index_dtype = arg_id_to_dtype[1]
result_dtypes = self.name.reduction_op.result_dtypes(scalar_dtype, # pylint: disable=no-member
index_dtype)
new_arg_id_to_dtype = arg_id_to_dtype.copy()
new_arg_id_to_dtype = dict(arg_id_to_dtype)
new_arg_id_to_dtype[-1] = result_dtypes[0]
new_arg_id_to_dtype[-2] = result_dtypes[1]
name_in_target = self.name.reduction_op.prefix(scalar_dtype, # pylint: disable=no-member
index_dtype) + "_op"

return self.copy(arg_id_to_dtype=new_arg_id_to_dtype,
return self.copy(arg_id_to_dtype=constantdict(new_arg_id_to_dtype),
name_in_target=name_in_target), callables_table

def with_descrs(self, arg_id_to_descr, callables_table):
from loopy.kernel.function_interface import ValueArgDescriptor
new_arg_id_to_descr = arg_id_to_descr.copy()
new_arg_id_to_descr = dict(arg_id_to_descr)
new_arg_id_to_descr[-1] = ValueArgDescriptor()
return (
self.copy(arg_id_to_descr=arg_id_to_descr),
self.copy(arg_id_to_descr=constantdict(arg_id_to_descr)),
callables_table)


Expand Down
2 changes: 1 addition & 1 deletion loopy/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ def map_call(self, expr, expn_state, assignees=None):
# }}}

# specializing the function according to the parameter description
new_clbl, self.clbl_inf_ctx = clbl.with_descrs(arg_id_to_descr,
new_clbl, self.clbl_inf_ctx = clbl.with_descrs(constantdict(arg_id_to_descr),
self.clbl_inf_ctx)

self.clbl_inf_ctx, new_func_id = (self.clbl_inf_ctx
Expand Down
25 changes: 14 additions & 11 deletions loopy/target/c/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from typing import TYPE_CHECKING, Any, Sequence, cast

import numpy as np
from constantdict import constantdict

import pymbolic.primitives as p
from cgen import (
Expand Down Expand Up @@ -563,9 +564,9 @@ def with_types(self, arg_id_to_dtype, callables_table):

return (
self.copy(name_in_target=name,
arg_id_to_dtype={
arg_id_to_dtype=constantdict({
0: NumpyType(dtype),
-1: NumpyType(result_dtype)}),
-1: NumpyType(result_dtype)})),
callables_table)

# binary functions
Expand Down Expand Up @@ -607,7 +608,7 @@ def with_types(self, arg_id_to_dtype, callables_table):
dtype = NumpyType(dtype)
return (
self.copy(name_in_target=name,
arg_id_to_dtype={-1: dtype, 0: dtype, 1: dtype}),
arg_id_to_dtype=constantdict({-1: dtype, 0: dtype, 1: dtype})),
callables_table)
elif name in ["max", "min"]:

Expand All @@ -632,9 +633,10 @@ def with_types(self, arg_id_to_dtype, callables_table):

return (
self.copy(name_in_target=f"lpy_{name}_{dtype.name}",
arg_id_to_dtype={-1: NumpyType(dtype),
0: NumpyType(dtype),
1: NumpyType(dtype)}),
arg_id_to_dtype=constantdict({
-1: NumpyType(dtype),
0: NumpyType(dtype),
1: NumpyType(dtype)})),
callables_table)
elif name == "isnan":
for id in arg_id_to_dtype:
Expand Down Expand Up @@ -662,9 +664,9 @@ def with_types(self, arg_id_to_dtype, callables_table):
return (
self.copy(
name_in_target=name,
arg_id_to_dtype={
arg_id_to_dtype=constantdict({
0: NumpyType(dtype),
-1: NumpyType(np.int32)}),
-1: NumpyType(np.int32)})),
callables_table)

def generate_preambles(self, target):
Expand Down Expand Up @@ -738,9 +740,10 @@ def with_types(self, arg_id_to_dtype, callables_table):

return (
self.copy(name_in_target=name_in_target,
arg_id_to_dtype={-1: arg_id_to_dtype[1],
0: NumpyType(np.int32),
1: arg_id_to_dtype[1]}),
arg_id_to_dtype=constantdict({
-1: arg_id_to_dtype[1],
0: NumpyType(np.int32),
1: arg_id_to_dtype[1]})),
callables_table)
else:
raise NotImplementedError(f"with_types for '{name}'")
Expand Down
42 changes: 28 additions & 14 deletions loopy/target/opencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from typing import TYPE_CHECKING, Literal, Sequence

import numpy as np
from constantdict import constantdict

from pymbolic import var
from pytools import memoize_method
Expand Down Expand Up @@ -217,9 +218,10 @@ def with_types(self, arg_id_to_dtype, callables_table):
# OpenCL C 2.2, Section 6.13.3: abs returns *u*gentype
from loopy.types import to_unsigned_dtype
return (self.copy(name_in_target=name,
arg_id_to_dtype={
arg_id_to_dtype=constantdict({
0: NumpyType(dtype),
-1: NumpyType(to_unsigned_dtype(dtype))}),
-1: NumpyType(to_unsigned_dtype(dtype))
})),
callables_table)
elif dtype.kind == "f":
name = "fabs"
Expand Down Expand Up @@ -251,8 +253,10 @@ def with_types(self, arg_id_to_dtype, callables_table):

return (
self.copy(name_in_target=name,
arg_id_to_dtype={0: NumpyType(dtype), -1:
NumpyType(dtype)}),
arg_id_to_dtype=constantdict({
0: NumpyType(dtype),
-1: NumpyType(dtype)
})),
callables_table)

# }}}
Expand Down Expand Up @@ -283,7 +287,9 @@ def with_types(self, arg_id_to_dtype, callables_table):
dtype = NumpyType(dtype)
return (
self.copy(name_in_target=name,
arg_id_to_dtype={-1: dtype, 0: dtype, 1: dtype}),
arg_id_to_dtype=constantdict({
-1: dtype, 0: dtype, 1: dtype
})),
callables_table)

elif name in ["max", "min"]:
Expand All @@ -305,7 +311,9 @@ def with_types(self, arg_id_to_dtype, callables_table):
dtype = NumpyType(common_dtype)
return (
self.copy(name_in_target=name,
arg_id_to_dtype={-1: dtype, 0: dtype, 1: dtype}),
arg_id_to_dtype=constantdict({
-1: dtype, 0: dtype, 1: dtype
})),
callables_table)
else:
# Unsupported type.
Expand All @@ -328,8 +336,9 @@ def with_types(self, arg_id_to_dtype, callables_table):
dtype = arg_id_to_dtype[0]
scalar_dtype, _offset, _field_name = dtype.numpy_dtype.fields["s0"]
return (
self.copy(name_in_target=name, arg_id_to_dtype={-1:
NumpyType(scalar_dtype), 0: dtype, 1: dtype}),
self.copy(name_in_target=name, arg_id_to_dtype=constantdict({
-1: NumpyType(scalar_dtype), 0: dtype, 1: dtype
})),
callables_table)

elif name == "pow":
Expand All @@ -352,8 +361,11 @@ def with_types(self, arg_id_to_dtype, callables_table):

return (
self.copy(name_in_target=name,
arg_id_to_dtype={-1: result_dtype,
0: common_dtype, 1: common_dtype}),
arg_id_to_dtype=constantdict({
-1: result_dtype,
0: common_dtype,
1: common_dtype
})),
callables_table)

elif name in _CL_SIMPLE_MULTI_ARG_FUNCTIONS:
Expand All @@ -379,8 +391,9 @@ def with_types(self, arg_id_to_dtype, callables_table):
raise LoopyError("%s does not support complex numbers"
% name)

updated_arg_id_to_dtype = {id: NumpyType(dtype) for id in range(-1,
num_args)}
updated_arg_id_to_dtype = constantdict({
id: NumpyType(dtype) for id in range(-1, num_args)
})

return (
self.copy(name_in_target=name,
Expand Down Expand Up @@ -409,8 +422,9 @@ def with_types(self, arg_id_to_dtype, callables_table):
NumpyType(dtype), count)

return (
self.copy(name_in_target="(%s%d) " % (base_tp_name, count),
arg_id_to_dtype=updated_arg_id_to_dtype),
self.copy(
name_in_target="(%s%d) " % (base_tp_name, count),
arg_id_to_dtype=constantdict(updated_arg_id_to_dtype)),
callables_table)

# does not satisfy any of the conditions needed for specialization.
Expand Down
Loading

0 comments on commit 1bb84f8

Please sign in to comment.