Skip to content

Commit

Permalink
Evaluate dtype_to_typeclass at use time (#1494)
Browse files Browse the repository at this point in the history
Prior to this PR, `dtype_to_typeclass` was evaluated at import time.
This means that the configuration entry `default_data_types` could not
be modified after importing dace in a meaningful way.
  • Loading branch information
tbennun authored Jan 4, 2024
1 parent bfe6923 commit 4d49452
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 69 deletions.
4 changes: 2 additions & 2 deletions dace/codegen/cppunparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,9 +746,9 @@ def _Repr(self, t):
def _Num(self, t):
t_n = t.value if sys.version_info >= (3, 8) else t.n
repr_n = repr(t_n)
# For complex values, use DTYPE_TO_TYPECLASS dictionary
# For complex values, use ``dtype_to_typeclass``
if isinstance(t_n, complex):
dtype = dtypes.DTYPE_TO_TYPECLASS[complex]
dtype = dtypes.dtype_to_typeclass(complex)

# Handle large integer values
if isinstance(t_n, int):
Expand Down
79 changes: 50 additions & 29 deletions dace/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ class typeclass(object):
2. Enabling declaration syntax: `dace.float32[M,N]`
3. Enabling extensions such as `dace.struct` and `dace.vector`
"""

def __init__(self, wrapped_type, typename=None):
# Convert python basic types
if isinstance(wrapped_type, str):
Expand Down Expand Up @@ -600,6 +601,7 @@ def result_type_of(lhs, *rhs):

class opaque(typeclass):
""" A data type for an opaque object, useful for C bindings/libnodes, i.e., MPI_Request. """

def __init__(self, typename):
self.type = typename
self.ctype = typename
Expand Down Expand Up @@ -635,6 +637,7 @@ class pointer(typeclass):
Example use:
`dace.pointer(dace.struct(x=dace.float32, y=dace.float32))`. """

def __init__(self, wrapped_typeclass):
self._typeclass = wrapped_typeclass
self.type = wrapped_typeclass.type
Expand Down Expand Up @@ -680,6 +683,7 @@ class vector(typeclass):
Example use: `dace.vector(dace.float32, 4)` becomes float4.
"""

def __init__(self, dtype: typeclass, vector_length: int):
self.vtype = dtype
self.type = dtype.type
Expand Down Expand Up @@ -737,6 +741,7 @@ class stringtype(pointer):
Python/generated code marshalling.
Used internally when `str` types are given
"""

def __init__(self):
super().__init__(int8)

Expand All @@ -756,6 +761,7 @@ class struct(typeclass):
Example use: `dace.struct(a=dace.int32, b=dace.float64)`.
"""

def __init__(self, name, **fields_and_types):
# self._data = fields_and_types
self.type = ctypes.Structure
Expand Down Expand Up @@ -859,6 +865,7 @@ class pyobject(opaque):
It cannot be used inside a DaCe program, but can be passed back to other Python callbacks.
Use with caution, and ensure the value is not removed by the garbage collector or the program will crash.
"""

def __init__(self):
super().__init__('pyobject')
self.bytes = ctypes.sizeof(ctypes.c_void_p)
Expand Down Expand Up @@ -892,6 +899,7 @@ def example(A: dace.float64[20], constant: dace.compiletime):
In the above code, ``constant`` will be replaced with its value at call time
during parsing.
"""

@staticmethod
def __descriptor__():
raise ValueError('All compile-time arguments must be provided in order to compile the SDFG ahead-of-time.')
Expand All @@ -914,6 +922,7 @@ class callback(typeclass):
"""
Looks like ``dace.callback([None, <some_native_type>], *types)``
"""

def __init__(self, return_types, *variadic_args):
from dace import data
if return_types is None:
Expand Down Expand Up @@ -1240,31 +1249,39 @@ class Typeclasses(aenum.AutoNumberEnum):
complex128 = complex128


DTYPE_TO_TYPECLASS = {
bool: typeclass(bool),
int: typeclass(int),
float: typeclass(float),
complex: typeclass(complex),
numpy.bool_: bool_,
numpy.int8: int8,
numpy.int16: int16,
numpy.int32: int32,
numpy.int64: int64,
numpy.intc: int32,
numpy.uint8: uint8,
numpy.uint16: uint16,
numpy.uint32: uint32,
numpy.uint64: uint64,
numpy.uintc: uint32,
numpy.float16: float16,
numpy.float32: float32,
numpy.float64: float64,
numpy.complex64: complex64,
numpy.complex128: complex128,
# FIXME
numpy.longlong: int64,
numpy.ulonglong: uint64
}
_bool = bool


def dtype_to_typeclass(dtype=None):
DTYPE_TO_TYPECLASS = {
_bool: typeclass(_bool),
int: typeclass(int),
float: typeclass(float),
complex: typeclass(complex),
numpy.bool_: bool_,
numpy.int8: int8,
numpy.int16: int16,
numpy.int32: int32,
numpy.int64: int64,
numpy.intc: int32,
numpy.uint8: uint8,
numpy.uint16: uint16,
numpy.uint32: uint32,
numpy.uint64: uint64,
numpy.uintc: uint32,
numpy.float16: float16,
numpy.float32: float32,
numpy.float64: float64,
numpy.complex64: complex64,
numpy.complex128: complex128,
# FIXME
numpy.longlong: int64,
numpy.ulonglong: uint64
}
if dtype is None:
return DTYPE_TO_TYPECLASS
return DTYPE_TO_TYPECLASS[dtype]


# Since this overrides the builtin bool, this should be after the
# DTYPE_TO_TYPECLASS dictionary
Expand Down Expand Up @@ -1354,6 +1371,7 @@ def isallowed(var, allow_recursive=False):
class DebugInfo:
""" Source code location identifier of a node/edge in an SDFG. Used for
IDE and debugging purposes. """

def __init__(self, start_line, start_column=0, end_line=-1, end_column=0, filename=None):
self.start_line = start_line
self.end_line = end_line if end_line >= 0 else start_line
Expand Down Expand Up @@ -1397,6 +1415,7 @@ def json_to_typeclass(obj, context=None):
def paramdec(dec):
""" Parameterized decorator meta-decorator. Enables using `@decorator`,
`@decorator()`, and `@decorator(...)` with the same function. """

@wraps(dec)
def layer(*args, **kwargs):
from dace import data
Expand Down Expand Up @@ -1478,20 +1497,22 @@ def can_allocate(storage: StorageType, schedule: ScheduleType):
# Host-only allocation
if storage in [StorageType.CPU_Heap, StorageType.CPU_Pinned, StorageType.CPU_ThreadLocal]:
return schedule in [
ScheduleType.CPU_Multicore, ScheduleType.CPU_Persistent, ScheduleType.Sequential, ScheduleType.MPI, ScheduleType.GPU_Default
ScheduleType.CPU_Multicore, ScheduleType.CPU_Persistent, ScheduleType.Sequential, ScheduleType.MPI,
ScheduleType.GPU_Default
]

# GPU-global memory
if storage is StorageType.GPU_Global:
return schedule in [
ScheduleType.CPU_Multicore, ScheduleType.CPU_Persistent, ScheduleType.Sequential, ScheduleType.MPI, ScheduleType.GPU_Default
ScheduleType.CPU_Multicore, ScheduleType.CPU_Persistent, ScheduleType.Sequential, ScheduleType.MPI,
ScheduleType.GPU_Default
]

# FPGA-global memory
if storage is StorageType.FPGA_Global:
return schedule in [
ScheduleType.CPU_Multicore, ScheduleType.CPU_Persistent, ScheduleType.Sequential, ScheduleType.MPI, ScheduleType.FPGA_Device,
ScheduleType.GPU_Default
ScheduleType.CPU_Multicore, ScheduleType.CPU_Persistent, ScheduleType.Sequential, ScheduleType.MPI,
ScheduleType.FPGA_Device, ScheduleType.GPU_Default
]

# FPGA-local memory
Expand Down
8 changes: 4 additions & 4 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3240,7 +3240,7 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
raise DaceSyntaxError(self, target, 'Variable "{}" used before definition'.format(name))

new_data, rng = None, None
dtype_keys = tuple(dtypes.DTYPE_TO_TYPECLASS.keys())
dtype_keys = tuple(dtypes.dtype_to_typeclass().keys())
if not (result in self.sdfg.symbols or symbolic.issymbolic(result) or isinstance(result, dtype_keys) or
(isinstance(result, str) and result in self.sdfg.arrays)):
raise DaceSyntaxError(
Expand Down Expand Up @@ -4653,14 +4653,14 @@ def visit_Num(self, node: NumConstant):
if isinstance(node.n, bool):
return dace.bool_(node.n)
if isinstance(node.n, (int, float, complex)):
return dtypes.DTYPE_TO_TYPECLASS[type(node.n)](node.n)
return dtypes.dtype_to_typeclass(type(node.n))(node.n)
return node.n

def visit_Constant(self, node: ast.Constant):
if isinstance(node.value, bool):
return dace.bool_(node.value)
if isinstance(node.value, (int, float, complex)):
return dtypes.DTYPE_TO_TYPECLASS[type(node.value)](node.value)
return dtypes.dtype_to_typeclass(type(node.value))(node.value)
if isinstance(node.value, (str, bytes)):
return StringLiteral(node.value)
return node.value
Expand Down Expand Up @@ -4745,7 +4745,7 @@ def _gettype(self, opnode: ast.AST) -> List[Tuple[str, str]]:
result.append((operand, type(self.sdfg.arrays[operand])))
elif isinstance(operand, str) and operand in self.scope_arrays:
result.append((operand, type(self.scope_arrays[operand])))
elif isinstance(operand, tuple(dtypes.DTYPE_TO_TYPECLASS.keys())):
elif isinstance(operand, tuple(dtypes.dtype_to_typeclass().keys())):
if isinstance(operand, (bool, numpy.bool_)):
result.append((operand, 'BoolConstant'))
else:
Expand Down
38 changes: 19 additions & 19 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def _numpy_full(pv: ProgramVisitor,
"""
is_data = False
if isinstance(fill_value, (Number, np.bool_)):
vtype = dtypes.DTYPE_TO_TYPECLASS[type(fill_value)]
vtype = dtypes.dtype_to_typeclass(type(fill_value))
elif isinstance(fill_value, sp.Expr):
vtype = _sym_type(fill_value)
else:
Expand Down Expand Up @@ -546,10 +546,10 @@ def _arange(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, *args, **kwargs):
if 'dtype' in kwargs and kwargs['dtype'] != None:
dtype = kwargs['dtype']
if not isinstance(dtype, dtypes.typeclass):
dtype = dtypes.DTYPE_TO_TYPECLASS[dtype]
dtype = dtypes.dtype_to_typeclass(dtype)
outname, outarr = sdfg.add_temp_transient(shape, dtype)
else:
dtype = dtypes.DTYPE_TO_TYPECLASS[type(shape[0])]
dtype = dtypes.dtype_to_typeclass(type(shape[0]))
outname, outarr = sdfg.add_temp_transient(shape, dtype)

state.add_mapped_tasklet(name="_numpy_arange_",
Expand Down Expand Up @@ -1076,8 +1076,8 @@ def _array_array_where(visitor: ProgramVisitor,
left_arr = sdfg.arrays.get(left_operand, None)
right_arr = sdfg.arrays.get(right_operand, None)

left_type = left_arr.dtype if left_arr else dtypes.DTYPE_TO_TYPECLASS[type(left_operand)]
right_type = right_arr.dtype if right_arr else dtypes.DTYPE_TO_TYPECLASS[type(right_operand)]
left_type = left_arr.dtype if left_arr else dtypes.dtype_to_typeclass(type(left_operand))
right_type = right_arr.dtype if right_arr else dtypes.dtype_to_typeclass(type(right_operand))

# Implicit Python coversion implemented as casting
arguments = [cond_arr, left_arr or left_type, right_arr or right_type]
Expand Down Expand Up @@ -1356,11 +1356,11 @@ def _np_result_type(nptypes):
# Fix for np.result_type returning platform-dependent types,
# e.g. np.longlong
restype = np.result_type(*nptypes)
if restype.type not in dtypes.DTYPE_TO_TYPECLASS.keys():
for k in dtypes.DTYPE_TO_TYPECLASS.keys():
if restype.type not in dtypes.dtype_to_typeclass().keys():
for k in dtypes.dtype_to_typeclass().keys():
if k == restype.type:
return dtypes.DTYPE_TO_TYPECLASS[k]
return dtypes.DTYPE_TO_TYPECLASS[restype.type]
return dtypes.dtype_to_typeclass(k)
return dtypes.dtype_to_typeclass(restype.type)


def _sym_type(expr: Union[symbolic.symbol, sp.Basic]) -> dtypes.typeclass:
Expand Down Expand Up @@ -1393,7 +1393,7 @@ def _result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basi
datatypes.append(arg.dtype)
dtypes_for_result.append(_representative_num(arg.dtype))
elif isinstance(arg, (Number, np.bool_)):
datatypes.append(dtypes.DTYPE_TO_TYPECLASS[type(arg)])
datatypes.append(dtypes.dtype_to_typeclass(type(arg)))
dtypes_for_result.append(arg)
elif symbolic.issymbolic(arg):
datatypes.append(_sym_type(arg))
Expand Down Expand Up @@ -1668,13 +1668,13 @@ def _array_const_binop(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, le
left_shape = left_arr.shape
storage = left_arr.storage
right_arr = None
right_type = dtypes.DTYPE_TO_TYPECLASS[type(right_operand)]
right_type = dtypes.dtype_to_typeclass(type(right_operand))
right_shape = [1]
arguments = [left_arr, right_operand]
tasklet_args = ['__in1', f'({str(right_operand)})']
else:
left_arr = None
left_type = dtypes.DTYPE_TO_TYPECLASS[type(left_operand)]
left_type = dtypes.dtype_to_typeclass(type(left_operand))
left_shape = [1]
right_arr = sdfg.arrays[right_operand]
right_type = right_arr.dtype
Expand Down Expand Up @@ -2229,7 +2229,7 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op

type1 = arr1.dtype.type
type2 = arr2.dtype.type
restype = dace.DTYPE_TO_TYPECLASS[np.result_type(type1, type2).type]
restype = dace.dtype_to_typeclass(np.result_type(type1, type2).type)

op3, arr3 = sdfg.add_temp_transient(output_shape, restype, arr1.storage)

Expand Down Expand Up @@ -3517,7 +3517,7 @@ def implement_ufunc(visitor: ProgramVisitor, ast_node: ast.Call, sdfg: SDFG, sta
ufunc_impl['operator'])
if 'dtype' in kwargs.keys():
dtype = kwargs['dtype']
if dtype in dtypes.DTYPE_TO_TYPECLASS.keys():
if dtype in dtypes.dtype_to_typeclass().keys():
result_type = dtype

# Create output data (if needed)
Expand Down Expand Up @@ -3709,7 +3709,7 @@ def implement_ufunc_reduce(visitor: ProgramVisitor, ast_node: ast.Call, sdfg: SD
datadesc = sdfg.arrays[arg]
result_type = datadesc.dtype
elif isinstance(arg, (Number, np.bool_)):
result_type = dtypes.DTYPE_TO_TYPECLASS[type(arg)]
result_type = dtypes.dtype_to_typeclass(type(arg))
elif isinstance(arg, sp.Basic):
result_type = _sym_type(arg)

Expand Down Expand Up @@ -4018,7 +4018,7 @@ def implement_ufunc_outer(visitor: ProgramVisitor, ast_node: ast.Call, sdfg: SDF
ufunc_impl['operator'])
if 'dtype' in kwargs.keys():
dtype = kwargs['dtype']
if dtype in dtypes.DTYPE_TO_TYPECLASS.keys():
if dtype in dtypes.dtype_to_typeclass().keys():
result_type = dtype

# Create output data (if needed)
Expand Down Expand Up @@ -4412,9 +4412,9 @@ def _make_datatype_converter(typeclass: str):
if typeclass == "bool":
dtype = dace.bool
elif typeclass in {"int", "float", "complex"}:
dtype = dtypes.DTYPE_TO_TYPECLASS[eval(typeclass)]
dtype = dtypes.dtype_to_typeclass(eval(typeclass))
else:
dtype = dtypes.DTYPE_TO_TYPECLASS[eval("np.{}".format(typeclass))]
dtype = dtypes.dtype_to_typeclass(eval("np.{}".format(typeclass)))

@oprepo.replaces(typeclass)
@oprepo.replaces("dace.{}".format(typeclass))
Expand Down Expand Up @@ -4711,7 +4711,7 @@ def _cupy_full(pv: ProgramVisitor,
the fill value.
"""
if isinstance(fill_value, (Number, np.bool_)):
vtype = dtypes.DTYPE_TO_TYPECLASS[type(fill_value)]
vtype = dtypes.dtype_to_typeclass(type(fill_value))
elif isinstance(fill_value, sp.Expr):
vtype = _sym_type(fill_value)
else:
Expand Down
8 changes: 4 additions & 4 deletions dace/libraries/blas/nodes/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ def _cast_to_dtype_str(value, dtype: dace.dtypes.typeclass) -> str:
cast_value = complex(value)

return "dace.{type}({real}, {imag})".format(
type=dace.DTYPE_TO_TYPECLASS[dtype].to_string(),
type=dace.dtype_to_typeclass(dtype).to_string(),
real=cast_value.real,
imag=cast_value.imag,
)
else:
return "dace.{}({})".format(dace.DTYPE_TO_TYPECLASS[dtype].to_string(), value)
return "dace.{}({})".format(dace.dtype_to_typeclass(dtype).to_string(), value)


@dace.library.expansion
Expand All @@ -52,7 +52,7 @@ def make_sdfg(node, parent_state, parent_sdfg):

dtype_a = outer_array_a.dtype.type
dtype_b = outer_array_b.dtype.type
dtype_c = dace.DTYPE_TO_TYPECLASS[np.result_type(dtype_a, dtype_b).type]
dtype_c = dace.dtype_to_typeclass(np.result_type(dtype_a, dtype_b).type)

if node.transA:
trans_shape_a = list(reversed(shape_a))
Expand Down Expand Up @@ -518,7 +518,7 @@ def expansion(node, parent_state, parent_sdfg, num_pes=32, tile_size_m=None):

dtype_a = outer_array_a.dtype.type
dtype_b = outer_array_b.dtype.type
dtype_c = dace.DTYPE_TO_TYPECLASS[np.result_type(dtype_a, dtype_b).type]
dtype_c = dace.dtype_to_typeclass(np.result_type(dtype_a, dtype_b).type)
shape_c = (shape_a[0], shape_b[1])
if node.transA:
raise NotImplementedError("GEMM FPGA expansion not implemented for transposed A.")
Expand Down
Loading

0 comments on commit 4d49452

Please sign in to comment.