Skip to content

Commit

Permalink
Remove the long-deprecated symbol.get/set methods (#1523)
Browse files Browse the repository at this point in the history
This also fixes FPGA CI, which now runs in parallel and depended on that
deprecated functionality.

---------

Co-authored-by: Alexandros Nikolaos Ziogas <[email protected]>
  • Loading branch information
tbennun and alexnick83 authored Feb 20, 2024
1 parent f28e960 commit c92ecc5
Show file tree
Hide file tree
Showing 141 changed files with 1,022 additions and 1,347 deletions.
55 changes: 18 additions & 37 deletions dace/codegen/compiled_sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@ def __init__(self, library_filename, program_name):
:param program_name: Name of the DaCe program (for use in finding
the stub library loader).
"""
self._stub_filename = os.path.join(
os.path.dirname(os.path.realpath(library_filename)),
f'libdacestub_{program_name}.{Config.get("compiler", "library_extension")}')
self._stub_filename = os.path.join(os.path.dirname(os.path.realpath(library_filename)),
f'libdacestub_{program_name}.{Config.get("compiler", "library_extension")}')
self._library_filename = os.path.realpath(library_filename)
self._stub = None
self._lib = None
Expand Down Expand Up @@ -219,7 +218,6 @@ def __init__(self, sdfg, lib: ReloadableDLL, argnames: List[str] = None):
self.has_gpu_code = True
break


def get_exported_function(self, name: str, restype=None) -> Optional[Callable[..., Any]]:
"""
Tries to find a symbol by name in the compiled SDFG, and convert it to a callable function
Expand All @@ -233,7 +231,6 @@ def get_exported_function(self, name: str, restype=None) -> Optional[Callable[..
except KeyError: # Function not found
return None


def get_state_struct(self) -> ctypes.Structure:
""" Attempt to parse the SDFG source code and extract the state struct. This method will parse the first
consecutive entries in the struct that are pointers. As soon as a non-pointer or other unparseable field is
Expand All @@ -247,7 +244,6 @@ def get_state_struct(self) -> ctypes.Structure:

return ctypes.cast(self._libhandle, ctypes.POINTER(self._try_parse_state_struct())).contents


def _try_parse_state_struct(self) -> Optional[Type[ctypes.Structure]]:
from dace.codegen.targets.cpp import mangle_dace_state_struct_name # Avoid import cycle
# the path of the main sdfg file containing the state struct
Expand Down Expand Up @@ -375,7 +371,6 @@ def _get_error_text(self, result: Union[str, int]) -> str:
else:
return result


def __call__(self, *args, **kwargs):
"""
Forwards the Python call to the compiled ``SDFG``.
Expand All @@ -400,13 +395,12 @@ def __call__(self, *args, **kwargs):
elif len(args) > 0 and self.argnames is not None:
kwargs.update(
# `_construct_args` will handle all of its arguments as kwargs.
{aname: arg for aname, arg in zip(self.argnames, args)}
)
argtuple, initargtuple = self._construct_args(kwargs) # Missing arguments will be detected here.
# Return values are cached in `self._lastargs`.
{aname: arg
for aname, arg in zip(self.argnames, args)})
argtuple, initargtuple = self._construct_args(kwargs) # Missing arguments will be detected here.
# Return values are cached in `self._lastargs`.
return self.fast_call(argtuple, initargtuple, do_gpu_check=True)


def fast_call(
self,
callargs: Tuple[Any, ...],
Expand Down Expand Up @@ -455,15 +449,13 @@ def fast_call(
self._lib.unload()
raise


def __del__(self):
if self._initialized is True:
self.finalize()
self._initialized = False
self._libhandle = ctypes.c_void_p(0)
self._lib.unload()


def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
"""
Main function that controls argument construction for calling
Expand All @@ -486,7 +478,7 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
typedict = self._typedict
if len(kwargs) > 0:
# Construct mapping from arguments to signature
arglist = []
arglist = []
argtypes = []
argnames = []
for a in sig:
Expand Down Expand Up @@ -536,10 +528,9 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
'you are doing, you can override this error in the '
'configuration by setting compiler.allow_view_arguments '
'to True.')
elif (not isinstance(atype, (dt.Array, dt.Structure)) and
not isinstance(atype.dtype, dtypes.callback) and
not isinstance(arg, (atype.dtype.type, sp.Basic)) and
not (isinstance(arg, symbolic.symbol) and arg.dtype == atype.dtype)):
elif (not isinstance(atype, (dt.Array, dt.Structure)) and not isinstance(atype.dtype, dtypes.callback)
and not isinstance(arg, (atype.dtype.type, sp.Basic))
and not (isinstance(arg, symbolic.symbol) and arg.dtype == atype.dtype)):
is_int = isinstance(arg, int)
if is_int and atype.dtype.type == np.int64:
pass
Expand Down Expand Up @@ -573,29 +564,23 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
# Retain only the element datatype for upcoming checks and casts
arg_ctypes = tuple(at.dtype.as_ctypes() for at in argtypes)

constants = self.sdfg.constants
callparams = tuple(
(actype(arg.get())
if isinstance(arg, symbolic.symbol)
else arg, actype, atype, aname
)
for arg, actype, atype, aname in zip(arglist, arg_ctypes, argtypes, argnames)
if not (symbolic.issymbolic(arg) and (hasattr(arg, 'name') and arg.name in constants))
)
constants = self.sdfg.constants
callparams = tuple((arg, actype, atype, aname)
for arg, actype, atype, aname in zip(arglist, arg_ctypes, argtypes, argnames)
if not (symbolic.issymbolic(arg) and (hasattr(arg, 'name') and arg.name in constants)))

symbols = self._free_symbols
initargs = tuple(
actype(arg) if not isinstance(arg, ctypes._SimpleCData) else arg
for arg, actype, atype, aname in callparams
if aname in symbols
)
actype(arg) if not isinstance(arg, ctypes._SimpleCData) else arg for arg, actype, atype, aname in callparams
if aname in symbols)

try:
# Replace arrays with their base host/device pointers
newargs = [None] * len(callparams)
for i, (arg, actype, atype, _) in enumerate(callparams):
if dtypes.is_array(arg):
newargs[i] = ctypes.c_void_p(_array_interface_ptr(arg, atype.storage)) # `c_void_p` is subclass of `ctypes._SimpleCData`.
newargs[i] = ctypes.c_void_p(_array_interface_ptr(
arg, atype.storage)) # `c_void_p` is subclass of `ctypes._SimpleCData`.
elif not isinstance(arg, (ctypes._SimpleCData)):
newargs[i] = actype(arg)
else:
Expand All @@ -607,11 +592,9 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
self._lastargs = newargs, initargs
return self._lastargs


def clear_return_values(self):
self._create_new_arrays = True


def _create_array(self, _: str, dtype: np.dtype, storage: dtypes.StorageType, shape: Tuple[int],
strides: Tuple[int], total_size: int):
ndarray = np.ndarray
Expand All @@ -636,7 +619,6 @@ def ndarray(*args, buffer=None, **kwargs):
# Create an array with the properties of the SDFG array
return ndarray(shape, dtype, buffer=zeros(total_size, dtype), strides=strides)


def _initialize_return_values(self, kwargs):
# Obtain symbol values from arguments and constants
syms = dict()
Expand Down Expand Up @@ -687,7 +669,6 @@ def _initialize_return_values(self, kwargs):
arr = self._create_array(*shape_desc)
self._return_arrays.append(arr)


def _convert_return_values(self):
# Return the values as they would be from a Python function
if self._return_arrays is None or len(self._return_arrays) == 0:
Expand Down
7 changes: 2 additions & 5 deletions dace/frontend/python/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,9 @@


def ndarray(shape, dtype=numpy.float64, *args, **kwargs):
""" Returns a numpy ndarray where all symbols have been evaluated to
numbers and types are converted to numpy types. """
repldict = {sym: sym.get() for sym in symbolic.symlist(shape).values()}
new_shape = [int(s.subs(repldict) if symbolic.issymbolic(s) else s) for s in shape]
""" Returns a numpy ndarray where all types are converted to numpy types. """
new_dtype = dtype.type if isinstance(dtype, dtypes.typeclass) else dtype
return numpy.ndarray(shape=new_shape, dtype=new_dtype, *args, **kwargs)
return numpy.ndarray(shape=shape, dtype=new_dtype, *args, **kwargs)


stream: Type[Deque[T]] = deque
Expand Down
10 changes: 1 addition & 9 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2125,16 +2125,8 @@ def specialize(self, symbols: Dict[str, Any]):
:param symbols: Values to specialize.
"""
# Set symbol values to add
syms = {
# If symbols are passed, extract the value. If constants are
# passed, use them directly.
name: val.get() if isinstance(val, dace.symbolic.symbol) else val
for name, val in symbols.items()
}

# Update constants
for k, v in syms.items():
for k, v in symbols.items():
self.add_constant(str(k), v)

def is_loaded(self) -> bool:
Expand Down
46 changes: 7 additions & 39 deletions dace/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,10 @@ def __new__(cls, name=None, dtype=DEFAULT_SYMBOL_TYPE, **assumptions):

self.dtype = dtype
self._constraints = []
self.value = None
return self

def set(self, value):
warnings.warn('symbol.set is deprecated, use keyword arguments', DeprecationWarning)
if value is not None:
# First, check constraints
self.check_constraints(value)

self.value = self.dtype(value)

def __getstate__(self):
return dict(self.assumptions0, **{'value': self.value, 'dtype': self.dtype, '_constraints': self._constraints})
return dict(self.assumptions0, **{'dtype': self.dtype, '_constraints': self._constraints})

def _eval_subs(self, old, new):
"""
Expand All @@ -85,15 +76,6 @@ def _eval_subs(self, old, new):
except AttributeError:
return None

def is_initialized(self):
return self.value is not None

def get(self):
warnings.warn('symbol.get is deprecated, use keyword arguments', DeprecationWarning)
if self.value is None:
raise UnboundLocalError('Uninitialized symbol value for \'' + self.name + '\'')
return self.value

def set_constraints(self, constraint_list):
try:
iter(constraint_list)
Expand Down Expand Up @@ -141,9 +123,6 @@ def check_constraints(self, value):
if fail is not None:
raise RuntimeError('Value %s invalidates constraint %s for symbol %s' % (str(value), str(fail), self.name))

def get_or_return(self, uninitialized_ret):
return self.value or uninitialized_ret


class SymExpr(object):
""" Symbolic expressions with support for an overapproximation expression.
Expand Down Expand Up @@ -287,13 +266,6 @@ def __gt__(self, other):
SymbolicType = Union[sympy.Basic, SymExpr]


def symvalue(val):
""" Returns the symbol value if it is a symbol. """
if isinstance(val, symbol):
return val.get()
return val


# http://stackoverflow.com/q/3844948/
def _checkEqualIvo(lst):
return not lst or lst.count(lst[0]) == len(lst)
Expand Down Expand Up @@ -333,9 +305,8 @@ def symlist(values):
return result


def evaluate(expr: Union[sympy.Basic, int, float],
symbols: Dict[Union[symbol, str], Union[int, float]]) -> \
Union[int, float, numpy.number]:
def evaluate(expr: Union[sympy.Basic, int, float], symbols: Dict[Union[symbol, str],
Union[int, float]]) -> Union[int, float, numpy.number]:
"""
Evaluates an expression to a constant based on a mapping from symbols
to values.
Expand All @@ -356,9 +327,7 @@ def evaluate(expr: Union[sympy.Basic, int, float],
return expr

# Evaluate all symbols
syms = {(sname if isinstance(sname, sympy.Symbol) else symbol(sname)):
sval.get() if isinstance(sval, symbol) else sval
for sname, sval in symbols.items()}
syms = {(sname if isinstance(sname, sympy.Symbol) else symbol(sname)): sval for sname, sval in symbols.items()}

# Filter out `None` values, callables, and iterables but not strings (for SymPy 1.12)
syms = {
Expand Down Expand Up @@ -1028,7 +997,7 @@ def visit_IfExp(self, node):
self.visit(node.orelse)],
keywords=[])
return ast.copy_location(new_node, node)

def visit_Subscript(self, node):
if isinstance(node.value, ast.Attribute):
attr = ast.Subscript(value=ast.Name(id=node.value.attr, ctx=ast.Load()), slice=node.slice, ctx=ast.Load())
Expand Down Expand Up @@ -1405,8 +1374,7 @@ def equal(a: SymbolicType, b: SymbolicType, is_length: bool = True) -> Union[boo
return sympy.ask(sympy.Q.is_true(sympy.Eq(*args)))


def symbols_in_code(code: str, potential_symbols: Set[str] = None,
symbols_to_ignore: Set[str] = None) -> Set[str]:
def symbols_in_code(code: str, potential_symbols: Set[str] = None, symbols_to_ignore: Set[str] = None) -> Set[str]:
"""
Tokenizes a code string for symbols and returns a set thereof.
Expand All @@ -1419,7 +1387,7 @@ def symbols_in_code(code: str, potential_symbols: Set[str] = None,
if potential_symbols is not None and len(potential_symbols) == 0:
# Don't bother tokenizing for an empty set of potential symbols
return set()

tokens = set(re.findall(_NAME_TOKENS, code))
if potential_symbols is not None:
tokens &= potential_symbols
Expand Down
38 changes: 17 additions & 21 deletions samples/fpga/gemv_fpga.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,12 @@ def make_outer_compute_state(sdfg):
return state


def make_sdfg(specialize):
def make_sdfg(specialize, N, M):

if specialize:
name = "gemv_transposed_{}x{}".format(N.get(), M.get())
name = "gemv_transposed_{}x{}".format(N, M)
else:
name = "gemv_transposed_{}xM".format(N.get())
name = "gemv_transposed_{}xM".format(N)

sdfg = dace.SDFG(name)

Expand All @@ -143,29 +143,25 @@ def run_gemv(n: int, m: int, specialize: bool):

print("==== Program start ====")

N.set(n)
if specialize:
print("Specializing M...")
M.set(m)

gemv = make_sdfg(specialize)
gemv.specialize(dict(N=N))
gemv = make_sdfg(specialize, n, m)
gemv.specialize(dict(N=n))

if not specialize:
M.set(m)
else:
gemv.specialize(dict(M=M))
if specialize:
gemv.specialize(dict(M=m))

print("Running GEMV {}x{} ({}specialized)".format(N.get(), M.get(), ("" if specialize else "not ")))
print("Running GEMV {}x{} ({}specialized)".format(n, m, ("" if specialize else "not ")))

A = dace.ndarray([M, N], dtype=dtype)
x = dace.ndarray([M], dtype=dtype)
y = dace.ndarray([N], dtype=dtype)
A = dace.ndarray([m, n], dtype=dtype)
x = dace.ndarray([m], dtype=dtype)
y = dace.ndarray([n], dtype=dtype)

# Intialize: randomize A, x and y
# A[:, :] = np.random.rand(M.get(), N.get()).astype(dtype.type)
# x[:] = np.random.rand(M.get()).astype(dtype.type)
# y[:] = np.random.rand(N.get()).astype(dtype.type)
# A[:, :] = np.random.rand(M, N).astype(dtype.type)
# x[:] = np.random.rand(M).astype(dtype.type)
# y[:] = np.random.rand(N).astype(dtype.type)
A[:, :] = 1
x[:] = 1
y[:] = 0
Expand All @@ -179,9 +175,9 @@ def run_gemv(n: int, m: int, specialize: bool):
if specialize:
gemv(A=A, x=x, y=x)
else:
gemv(A=A, M=M, x=x, y=y)
gemv(A=A, M=m, x=x, y=y)

residual = np.linalg.norm(y - regression) / (N.get() * M.get())
residual = np.linalg.norm(y - regression) / (n * m)
print("Residual:", residual)
diff = np.abs(y - regression)
wrong_elements = np.transpose(np.nonzero(diff >= 0.01))
Expand All @@ -191,7 +187,7 @@ def run_gemv(n: int, m: int, specialize: bool):
if residual >= 0.01 or highest_diff >= 0.01:
print("Verification failed!")
print("Residual: {}".format(residual))
print("Incorrect elements: {} / {}".format(wrong_elements.shape[0], (N.get() * M.get())))
print("Incorrect elements: {} / {}".format(wrong_elements.shape[0], (n * m)))
print("Highest difference: {}".format(highest_diff))
print("** Result:\n", y)
print("** Reference:\n", regression)
Expand Down
Loading

0 comments on commit c92ecc5

Please sign in to comment.