Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use immutable objects when dealing with loopy callables #4002

Merged
merged 4 commits into from
Feb 5, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions pyop2/codegen/rep2loopy.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@
import numpy
from dataclasses import dataclass

from immutabledict import immutabledict
import loopy
from loopy.symbolic import SubArrayRef
from loopy.expression import dtype_to_type_context
@@ -71,15 +72,15 @@ def symbol_mangler(kernel, name):
class PetscCallable(loopy.ScalarCallable):

def with_types(self, arg_id_to_dtype, callables_table):
new_arg_id_to_dtype = arg_id_to_dtype.copy()
new_arg_id_to_dtype = dict(arg_id_to_dtype)
return (self.copy(
name_in_target=self.name,
arg_id_to_dtype=new_arg_id_to_dtype), callables_table)
arg_id_to_dtype=immutabledict(new_arg_id_to_dtype)), callables_table)

def with_descrs(self, arg_id_to_descr, callables_table):
from loopy.kernel.function_interface import ArrayArgDescriptor
from loopy.kernel.array import FixedStrideArrayDimTag
new_arg_id_to_descr = arg_id_to_descr.copy()
new_arg_id_to_descr = dict(arg_id_to_descr)
for i, des in arg_id_to_descr.items():
# petsc takes 1D arrays as arguments
if isinstance(des, ArrayArgDescriptor):
@@ -88,7 +89,7 @@ def with_descrs(self, arg_id_to_descr, callables_table):
for i in range(len(des.shape)))
new_arg_id_to_descr[i] = des.copy(dim_tags=dim_tags)

return (self.copy(arg_id_to_descr=new_arg_id_to_descr),
return (self.copy(arg_id_to_descr=immutabledict(new_arg_id_to_descr)),
callables_table)

def generate_preambles(self, target):
@@ -143,7 +144,7 @@ def with_types(self, arg_id_to_dtype, callables_table):
dtypes[-1] = NumpyType(dtypes[0].dtype)

return (self.copy(name_in_target=self.name_in_target,
arg_id_to_dtype=dtypes),
arg_id_to_dtype=immutabledict(dtypes)),
callables_table)

def emit_call_insn(self, insn, target, expression_to_code_mapper):
@@ -222,15 +223,15 @@ def __init__(self, name, parameters, arg_id_to_dtype=None, arg_id_to_descr=None,
object.__setattr__(self, "parameters", tuple(parameters))

def with_types(self, arg_id_to_dtype, callables_table):
new_arg_id_to_dtype = arg_id_to_dtype.copy()
new_arg_id_to_dtype = dict(arg_id_to_dtype)
return self.copy(
name_in_target=self.name,
arg_id_to_dtype=new_arg_id_to_dtype), callables_table
arg_id_to_dtype=immutabledict(new_arg_id_to_dtype)), callables_table

def with_descrs(self, arg_id_to_descr, callables_table):
from loopy.kernel.function_interface import ArrayArgDescriptor
from loopy.kernel.array import FixedStrideArrayDimTag
new_arg_id_to_descr = arg_id_to_descr.copy()
new_arg_id_to_descr = dict(arg_id_to_descr)
for i, des in arg_id_to_descr.items():
# 1D arrays
if isinstance(des, ArrayArgDescriptor):
@@ -242,7 +243,7 @@ def with_descrs(self, arg_id_to_descr, callables_table):
for i in range(len(des.shape))
)
new_arg_id_to_descr[i] = des.copy(dim_tags=dim_tags)
return (self.copy(arg_id_to_descr=new_arg_id_to_descr), callables_table)
return (self.copy(arg_id_to_descr=immutabledict(new_arg_id_to_descr)), callables_table)

def emit_call_insn(self, insn, target, expression_to_code_mapper):
# reorder arguments, e.g. a,c = f(b,d) to f(a,b,c,d)