Skip to content

Commit

Permalink
clean up types a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Jan 29, 2025
1 parent 7b0d862 commit 2449f53
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 11 deletions.
6 changes: 2 additions & 4 deletions loopy/kernel/function_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,7 @@ def __init__(self,
try:
hash(arg_id_to_dtype)
except TypeError:
if arg_id_to_dtype is None:
arg_id_to_dtype = {}
assert arg_id_to_dtype is not None
arg_id_to_dtype = constantdict(arg_id_to_dtype)
warn("arg_id_to_dtype passed to InKernelCallable was not hashable. "
"This usage is deprecated and will stop working in 2026.",
Expand All @@ -358,8 +357,7 @@ def __init__(self,
try:
hash(arg_id_to_descr)
except TypeError:
if arg_id_to_descr is None:
arg_id_to_descr = {}
assert arg_id_to_descr is not None
arg_id_to_descr = constantdict(arg_id_to_descr)
warn("arg_id_to_descr passed to InKernelCallable was not hashable. "
"This usage is deprecated and will stop working in 2026.",
Expand Down
5 changes: 4 additions & 1 deletion loopy/target/c/c_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Sequence

import numpy as np

from constantdict import constantdict

from codepy.jit import compile_from_string
from codepy.toolchain import GCCToolchain, ToolchainGuessError, guess_toolchain

Expand Down Expand Up @@ -500,7 +503,7 @@ def get_wrapper_generator(self):

@memoize_method
def translation_unit_info(self,
arg_to_dtype: Mapping[str, LoopyType] | None = None) -> _KernelInfo:
arg_to_dtype: constantdict[str, LoopyType] | None = None) -> _KernelInfo:
t_unit = self.get_typed_and_scheduled_translation_unit(arg_to_dtype)

from loopy.codegen import generate_code_v2
Expand Down
10 changes: 5 additions & 5 deletions loopy/target/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ def check_for_required_array_arguments(self, input_args):
"your argument.")

def get_typed_and_scheduled_translation_unit_uncached(
self, arg_to_dtype: Mapping[str, LoopyType] | None
self, arg_to_dtype: constantdict[str, LoopyType] | None
) -> TranslationUnit:
t_unit = self.t_unit

Expand All @@ -827,15 +827,15 @@ def get_typed_and_scheduled_translation_unit_uncached(
# FIXME: This is not so nice. This transfers types from the
# subarrays of sep-tagged arrays to the 'main' array, because
# type inference fails otherwise.
mm = dict(arg_to_dtype)
mm = arg_to_dtype.mutate()
for name, sep_info in self.sep_info.items():
if entry_knl.arg_dict[name].dtype is None:
for sep_name in sep_info.subarray_names.values():
if sep_name in arg_to_dtype:
mm[name] = arg_to_dtype[sep_name]
del mm[sep_name]

arg_to_dtype = constantdict(mm)
arg_to_dtype = mm.finish()

from loopy.kernel.tools import add_dtypes
t_unit = t_unit.with_kernel(add_dtypes(entry_knl, arg_to_dtype))
Expand All @@ -854,7 +854,7 @@ def get_typed_and_scheduled_translation_unit_uncached(
return t_unit

def get_typed_and_scheduled_translation_unit(
self, arg_to_dtype: Mapping[str, LoopyType] | None
self, arg_to_dtype: constantdict[str, LoopyType] | None
) -> TranslationUnit:
from loopy import CACHING_ENABLED

Expand Down Expand Up @@ -904,7 +904,7 @@ def get_highlighted_code(self, entrypoint, arg_to_dtype=None, code=None):

def get_code(
self, entrypoint: str,
arg_to_dtype: Mapping[str, LoopyType] | None = None) -> str:
arg_to_dtype: constantdict[str, LoopyType] | None = None) -> str:
kernel = self.get_typed_and_scheduled_translation_unit(arg_to_dtype)

from loopy.codegen import generate_code_v2
Expand Down
4 changes: 3 additions & 1 deletion loopy/target/pyopencl_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

import numpy as np

from constantdict import constantdict

from pytools import memoize_method
from pytools.codegen import CodeGenerator, Indentation

Expand Down Expand Up @@ -311,7 +313,7 @@ def get_wrapper_generator(self):
@memoize_method
def translation_unit_info(
self,
arg_to_dtype: Mapping[str, LoopyType] | None = None) -> _KernelInfo:
arg_to_dtype: constantdict[str, LoopyType] | None = None) -> _KernelInfo:
t_unit = self.get_typed_and_scheduled_translation_unit(arg_to_dtype)

# FIXME: now just need to add the types to the arguments
Expand Down

0 comments on commit 2449f53

Please sign in to comment.