From a4e029116e36b60c838ee2130ea8983b18bc153a Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Thu, 27 Jun 2024 16:12:09 +0200 Subject: [PATCH] Improve SDFG work-depth analysis and add SDFG simulated operational intensity analysis (#1607) This PR is here to merge the reviewed PR #1495, which has remained inactive for a long time with minor comments open. The comments have been addressed here and merge conflicts have been resolved. --------- Co-authored-by: Cliff Hodel Co-authored-by: Cliff Hodel <111381329+hodelcl@users.noreply.github.com> Co-authored-by: Cliff Hodel --- .../assumptions.py | 4 +- .../helpers.py | 14 +- .../performance_evaluation/op_in_helpers.py | 283 ++++++++ .../operational_intensity.py | 639 ++++++++++++++++++ .../work_depth.py | 90 ++- tests/sdfg/operational_intensity_test.py | 148 ++++ tests/sdfg/work_depth_test.py | 330 +++++++++ tests/sdfg/work_depth_tests.py | 262 ------- 8 files changed, 1473 insertions(+), 297 deletions(-) rename dace/sdfg/{work_depth_analysis => performance_evaluation}/assumptions.py (98%) rename dace/sdfg/{work_depth_analysis => performance_evaluation}/helpers.py (96%) create mode 100644 dace/sdfg/performance_evaluation/op_in_helpers.py create mode 100644 dace/sdfg/performance_evaluation/operational_intensity.py rename dace/sdfg/{work_depth_analysis => performance_evaluation}/work_depth.py (91%) create mode 100644 tests/sdfg/operational_intensity_test.py create mode 100644 tests/sdfg/work_depth_test.py delete mode 100644 tests/sdfg/work_depth_tests.py diff --git a/dace/sdfg/work_depth_analysis/assumptions.py b/dace/sdfg/performance_evaluation/assumptions.py similarity index 98% rename from dace/sdfg/work_depth_analysis/assumptions.py rename to dace/sdfg/performance_evaluation/assumptions.py index 6e311cde0c..ec8c61ef73 100644 --- a/dace/sdfg/work_depth_analysis/assumptions.py +++ b/dace/sdfg/performance_evaluation/assumptions.py @@ -153,7 +153,7 @@ def propagate_assumptions_equal_symbols(condensed_assumptions): equality_subs1.update({sym: sp.Symbol(uf.find(sym))}) equality_subs2 = {} - # In a second step, each symbol gets replace with its equal number (if present) + # In a second step, each symbol gets replaced with its equal number (if present) # using equality_subs2. for sym, assum in condensed_assumptions.items(): for e in assum.equal: @@ -182,7 +182,7 @@ def parse_assumptions(assumptions, array_symbols): Parses a list of assumptions into substitution dictionaries. Firstly, it gathers all assumptions and keeps only the strongest ones. Afterwards it constructs two substitution dicts for the equality assumptions: First dict for symbol==symbol assumptions; second dict for symbol==number assumptions. - The other assumptions get handles by N tuples of substitution dicts (N = max number of concurrent + The other assumptions get handled by N tuples of substitution dicts (N = max number of concurrent assumptions for a single symbol). Each tuple is responsible for at most one assumption for each symbol. First dict in the tuple substitutes the symbol with the assumption; second dict restores the initial symbol. 
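To make the substitution scheme described above concrete, here is a minimal usage sketch of parse_assumptions, assembled from the tests added later in this patch (the assumption string and the expected Max simplification are taken from test_assumption_system; the import path simply reflects the package rename performed by this PR and is the only assumed detail):

import sympy as sp
from dace.sdfg.performance_evaluation.assumptions import parse_assumptions

x, y = sp.symbols('x y')
expr = sp.Max(x, y)

# parse_assumptions returns two equality substitution dicts plus N tuples of
# (substitute, restore) dicts for the remaining assumptions, as described above.
equality_subs, all_subs = parse_assumptions(['x>y'], set())
expr = expr.subs(equality_subs[0]).subs(equality_subs[1])
for subs1, subs2 in all_subs:
    expr = expr.subs(subs1).subs(subs2)

print(expr)  # per the test cases added further down, this simplifies to x
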
diff --git a/dace/sdfg/work_depth_analysis/helpers.py b/dace/sdfg/performance_evaluation/helpers.py similarity index 96% rename from dace/sdfg/work_depth_analysis/helpers.py rename to dace/sdfg/performance_evaluation/helpers.py index 31d3661509..552e2917cc 100644 --- a/dace/sdfg/work_depth_analysis/helpers.py +++ b/dace/sdfg/performance_evaluation/helpers.py @@ -214,6 +214,10 @@ def get_backedges(graph: nx.DiGraph, return backedges +class LoopExtractionError(Exception): + pass + + def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph): """ Detects loops in a SDFG. For each loop, it identifies (node, oNode, exit). @@ -241,15 +245,15 @@ def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph): # sanity check: if sdfg_nx.in_degree(artificial_end_node) == 0: - raise ValueError('No end node could be determined in the SDFG') + raise LoopExtractionError('No end node could be determined in the SDFG') # compute dominators and backedges iDoms = nx.immediate_dominators(sdfg_nx, start) - allDom, domTree = get_domtree(sdfg_nx, start, iDoms) + allDom, _ = get_domtree(sdfg_nx, start, iDoms) reversed_sdfg_nx = sdfg_nx.reverse() iPostDoms = nx.immediate_dominators(reversed_sdfg_nx, artificial_end_node) - allPostDoms, postDomTree = get_domtree(reversed_sdfg_nx, artificial_end_node, iPostDoms) + _, postDomTree = get_domtree(reversed_sdfg_nx, artificial_end_node, iPostDoms) backedges = get_backedges(sdfg_nx, start) backedgesDstDict = {} @@ -297,7 +301,7 @@ def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph): exitCandidates.add(succ) if len(exitCandidates) == 0: - raise ValueError('failed to find any exit nodes') + raise LoopExtractionError('failed to find any exit nodes') elif len(exitCandidates) > 1: # Find the exit candidate that sits highest up in the # postdominator tree (i.e., has the lowest level). @@ -323,7 +327,7 @@ def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph): if len(minSet) > 0: exitCandidates = minSet else: - raise ValueError('failed to find exit minSet') + raise LoopExtractionError('failed to find exit minSet') # now we have a triple (node, oNode, exitCandidates) nodes_oNodes_exits.append((node, oNode, exitCandidates)) diff --git a/dace/sdfg/performance_evaluation/op_in_helpers.py b/dace/sdfg/performance_evaluation/op_in_helpers.py new file mode 100644 index 0000000000..6f4481868f --- /dev/null +++ b/dace/sdfg/performance_evaluation/op_in_helpers.py @@ -0,0 +1,283 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" Contains class CacheLineTracker which keeps track of all arrays of an SDFG and their cache line position +and class AccessStack which which corresponds to the stack used to compute the stack distance. +Further, provides a curve fitting method and plotting function. """ + +import warnings +from dace.data import Array +import sympy as sp +from collections import deque +from scipy.optimize import curve_fit +import numpy as np +from dace import symbol + + +class CacheLineTracker: + """ A CacheLineTracker maps data container accesses to the corresponding accessed cache line. 
""" + + def __init__(self, L) -> None: + self.array_info = {} + self.start_lines = {} + self.next_free_line = 0 + self.L = L + + def add_array(self, name: str, a: Array, mapping): + if name not in self.start_lines: + # new array encountered + self.array_info[name] = a + self.start_lines[name] = self.next_free_line + # increase next_free_line + self.next_free_line += (a.total_size.subs(mapping) * a.dtype.bytes + self.L - 1) // self.L # ceil division + + def cache_line_id(self, name: str, access: [int], mapping): + arr = self.array_info[name] + one_d_index = 0 + for dim in range(len(access)): + i = access[dim] + one_d_index += (i + sp.sympify(arr.offset[dim]).subs(mapping)) * sp.sympify(arr.strides[dim]).subs(mapping) + + # divide by L to get the cache line id + return self.start_lines[name] + (one_d_index * arr.dtype.bytes) // self.L + + def copy(self): + new_clt = CacheLineTracker(self.L) + new_clt.array_info = dict(self.array_info) + new_clt.start_lines = dict(self.start_lines) + new_clt.next_free_line = self.next_free_line + return new_clt + + +class Node: + + def __init__(self, val: int, n=None) -> None: + self.v = val + self.next = n + + +class AccessStack: + """ A stack of cache line ids. For each memory access, we search the corresponding cache line id + in the stack, report its distance and move it to the top of the stack. If the id was not found, + we report a distance of -1. """ + + def __init__(self, C) -> None: + self.top = None + self.num_calls = 0 + self.length = 0 + self.C = C + + def touch(self, id): + self.num_calls += 1 + curr = self.top + prev = None + found = False + distance = 0 + while curr is not None: + # check if we found id + if curr.v == id: + # take curr node out + if prev is not None: + prev.next = curr.next + curr.next = self.top + self.top = curr + + found = True + break + + # iterate further + prev = curr + curr = curr.next + distance += 1 + + if not found: + # we accessed this cache line for the first time ever + self.top = Node(id, self.top) + self.length += 1 + distance = -1 + + return distance + + def in_cache_as_list(self): + """ + Returns a list of cache ids currently in cache. Index 0 is the most recently used. 
+ """ + res = deque() + curr = self.top + dist = 0 + while curr is not None and dist < self.C: + res.append(curr.v) + curr = curr.next + dist += 1 + return res + + def debug_print(self): + # prints the whole stack + print('\n') + curr = self.top + while curr is not None: + print(curr.v, end=', ') + curr = curr.next + print('\n') + + def copy(self): + new_stack = AccessStack(self.C) + cache_content = self.in_cache_as_list() + if len(cache_content) > 0: + new_top_value = cache_content.popleft() + new_stack.top = Node(new_top_value) + curr = new_stack.top + for x in cache_content: + curr.next = Node(x) + curr = curr.next + return new_stack + + +def plot(x, work_map, cache_misses, op_in_map, symbol_name, C, L, sympy_f, element, name): + plt = None + try: + import matplotlib.pyplot as plt_import + plt = plt_import + except ModuleNotFoundError: + pass + + if plt is None: + warnings.warn('Plotting only possible with matplotlib installed') + return + + work_map = work_map[element] + cache_misses = cache_misses[element] + op_in_map = op_in_map[element] + sympy_f = sympy_f[element] + + a = np.linspace(1, max(x) + 5, max(x) * 4) + + fig, ax = plt.subplots(1, 2, figsize=(12, 5)) + ax[0].scatter(x, cache_misses, label=f'C={C*L}, L={L}') + b = [] + for curr in a: + b.append(sp.N(sp.sympify(sympy_f).subs(symbol_name, curr))) + ax[0].plot(a, b) + + c = [] + for i, curr in enumerate(x): + if work_map[0].subs(symbol_name, curr) == 0: + c.append(0) + elif (cache_misses[i] * L) == 0: + c.append(9999) + else: + c.append(work_map[0].subs(symbol_name, curr) / (cache_misses[i] * L)) + c = np.array(c).astype(np.float64) + + ax[1].scatter(x, c, label=f'C={C*L}, L={L}') + b = [] + for curr in a: + b.append(sp.N(sp.sympify(op_in_map).subs(symbol_name, curr))) + ax[1].plot(a, b) + + ax[0].set_ylim(bottom=0, top=max(cache_misses) + max(cache_misses) / 10) + ax[0].set_xlim(left=0, right=max(x) + 1) + ax[0].set_xlabel(symbol_name) + ax[0].set_ylabel('Number of Cache Misses') + ax[0].set_title(name) + ax[0].legend(fancybox=True, framealpha=0.5) + + ax[1].set_ylim(bottom=0, top=max(c) + max(c) / 10) + ax[1].set_xlim(left=0, right=max(x) + 1) + ax[1].set_xlabel(symbol_name) + ax[1].set_ylabel('Operational Intensity') + ax[1].set_title(name) + + fig.show() + + +def compute_mape(f, test_x, test_y, test_set_size): + total_error = 0 + for i in range(test_set_size): + pred = f(test_x[i]) + err = abs(test_y[i] - pred) + total_error += err / test_y[i] + return total_error / test_set_size + + +def r_squared(pred, y): + if np.sum(np.square(y - y.mean())) <= 0.0001: + return 1 + return 1 - np.sum(np.square(y - pred)) / np.sum(np.square(y - y.mean())) + + +def find_best_model(x, y, I, J, symbol_name): + """ Find the best model out of all combinations of (i, j) from I and J via leave-one-out cross validation. 
""" + min_error = None + for i in I: + for j in J: + # current model + if i == 0 and j == 0: + + def f(x, b): + return b * np.ones_like(x) + else: + + def f(x, c, b): + return c * np.power(x, i) * np.power(np.log2(x), j) + b + + error_sum = 0 + for left_out in range(len(x)): + xx = np.delete(x, left_out) + yy = np.delete(y, left_out) + try: + param, _ = curve_fit(f, xx, yy) + + # predict on left out sample + pred = f(x[left_out], *param) + squared_error = np.square(pred - y[left_out]) + error_sum += squared_error + except RuntimeError: + # triggered if no fit was found --> give huge error + error_sum += 999999 + + mean_error = error_sum / len(x) + if min_error is None or mean_error < min_error: + # new best model found + min_error = mean_error + best_i_j = (i, j) + if best_i_j[0] == 0 and best_i_j[1] == 0: + + def f_best(x, b): + return b * np.ones_like(x) + else: + + def f_best(x, c, b): + return c * np.power(x, best_i_j[0]) * np.power(np.log2(x), best_i_j[1]) + b + + # fit best model to all data points + final_p, _ = curve_fit(f_best, x, y) + + def final_f(x): + return f_best(x, *final_p) + + if best_i_j[0] == 0 and best_i_j[1] == 0: + sympy_f = final_p[0] + else: + sympy_f = sp.simplify(final_p[0] * symbol(symbol_name)**best_i_j[0] * + sp.log(symbol(symbol_name), 2)**best_i_j[1] + final_p[1]) + # compute r^2 + r_s = r_squared(final_f(x), y) + return final_f, sympy_f, r_s + + +def fit_curve(x, y, symbol_name): + """ + Fits a function throught the data set. + + :param x: The independent values. + :param y: The dependent values. + :param symbol_name: The name of the SDFG symbol. + """ + x = np.array(x).astype(np.int32) + y = np.array(y).astype(np.float64) + + # model search space + I = [x / 4 for x in range(13)] + J = [0, 1, 2] + final_f, sympy_final_f, r_s = find_best_model(x, y, I, J, symbol_name) + + return final_f, sympy_final_f, r_s diff --git a/dace/sdfg/performance_evaluation/operational_intensity.py b/dace/sdfg/performance_evaluation/operational_intensity.py new file mode 100644 index 0000000000..26eee2f253 --- /dev/null +++ b/dace/sdfg/performance_evaluation/operational_intensity.py @@ -0,0 +1,639 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" Analyses the operational intensity of an input SDFG. Can be used as a Python script +or from the VS Code extension. """ + +import argparse +from collections import deque +from dace.sdfg import nodes as nd +from dace import SDFG, SDFGState, dtypes +from typing import Tuple, Dict +import os +import sympy as sp +from copy import deepcopy +from dace.symbolic import pystr_to_symbolic, SymExpr + +from dace.sdfg.performance_evaluation.helpers import get_uuid +from dace.transformation.passes.symbol_ssa import StrictSymbolSSA +from dace.transformation.pass_pipeline import FixedPointPipeline + +from dace.data import Array +from dace.sdfg.performance_evaluation.op_in_helpers import CacheLineTracker, AccessStack, fit_curve, plot, compute_mape +from dace.sdfg.performance_evaluation.work_depth import analyze_sdfg, get_tasklet_work + + +class SymbolRange(): + """ Used to describe an SDFG symbol associated with a range (start, stop, step) of values. 
""" + + def __init__(self, start_stop_step) -> None: + self.r = range(*start_stop_step) + self.i = iter(self.r) + + def next(self): + try: + r = next(self.i) + except StopIteration: + r = -1 + return r + + def to_list(self): + return list(self.r) + + def max_value(self): + return max(self.to_list()) + + +def update_map(op_in_map, uuid, new_misses, average=True): + if average: + if uuid in op_in_map: + misses, encounters = op_in_map[uuid] + op_in_map[uuid] = (misses + new_misses, encounters + 1) + else: + op_in_map[uuid] = (new_misses, 1) + else: + if uuid in op_in_map: + misses, encounters = op_in_map[uuid] + op_in_map[uuid] = (misses + new_misses, encounters) + else: + op_in_map[uuid] = (new_misses, 1) + + +def calculate_op_in(op_in_map, work_map, stringify=False, assumptions={}): + """ Calculates the operational intensity for each SDFG element from work and bytes loaded. """ + for uuid in op_in_map: + work = work_map[uuid][0].subs(assumptions) + if work == 0 and op_in_map[uuid] == 0: + op_in_map[uuid] = 0 + elif work != 0 and op_in_map[uuid] == 0: + # everything was read from cache --> infinite op_in + op_in_map[uuid] = sp.oo + else: + # op_in > 0 --> divide normally + op_in_map[uuid] = sp.N(work / op_in_map[uuid]) + if stringify: + op_in_map[uuid] = str(op_in_map[uuid]) + + +def mem_accesses_on_path(states): + mem_accesses = 0 + for state in states: + mem_accesses += len(state.read_and_write_sets()) + return mem_accesses + + +def find_states_between(sdfg: SDFG, start_state: SDFGState, end_state: SDFGState): + traversal_q = deque() + traversal_q.append(start_state) + visited = set() + states = [] + while traversal_q: + curr_state = traversal_q.popleft() + if curr_state == end_state: + continue + if curr_state not in visited: + visited.add(curr_state) + states.append(curr_state) + for e in sdfg.out_edges(curr_state): + traversal_q.append(e.dst) + return states + + +def find_merge_state(sdfg: SDFG, state: SDFGState): + """ + Adapted from ``cfg.stateorder_topological_sort``. + """ + from dace.sdfg.analysis import cfg + + # Get parent states + ptree = cfg.state_parent_tree(sdfg) + + # Annotate branches + adf = cfg.acyclic_dominance_frontier(sdfg) + oedges = sdfg.out_edges(state) + # Skip if not branch + if len(oedges) <= 1: + return + # Skip if natural loop + if len(oedges) == 2 and ((ptree[oedges[0].dst] == state and ptree[oedges[1].dst] != state) or + (ptree[oedges[1].dst] == state and ptree[oedges[0].dst] != state)): + return + + # If branch without else (adf of one successor is equal to the other) + if len(oedges) == 2: + if {oedges[0].dst} & adf[oedges[1].dst]: + return oedges[0].dst + elif {oedges[1].dst} & adf[oedges[0].dst]: + return oedges[1].dst + + # Try to obtain common DF to find merge state + common_frontier = set() + for oedge in oedges: + frontier = adf[oedge.dst] + if not frontier: + frontier = {oedge.dst} + common_frontier |= frontier + if len(common_frontier) == 1: + return next(iter(common_frontier)) + print(f'WARNING: No merge state could be detected for branch state "{state.name}".', ) + + +def symeval(val, symbols): + """ + Takes a sympy expression and substitutes its symbols according to a dict { old_symbol: new_symbol}. + + :param val: The expression we are updating. + :param symbols: Dictionary of key value pairs { old_symbol: new_symbol}. 
+ """ + first_replacement = {pystr_to_symbolic(k): pystr_to_symbolic('__REPLSYM_' + k) for k in symbols.keys()} + second_replacement = {pystr_to_symbolic('__REPLSYM_' + k): v for k, v in symbols.items()} + return sp.simplify(val.subs(first_replacement).subs(second_replacement)) + + +def evaluate_symbols(base, new): + result = {} + for k, v in new.items(): + result[k] = symeval(v, base) + return result + + +def update_mapping(mapping, e): + update = {} + for k, v in e.data.assignments.items(): + if '[' not in k and '[' not in v: + update[k] = pystr_to_symbolic(v).subs(mapping) + mapping.update(update) + + +def update_map_iterators(map, mapping): + # update the map params and return False + # if all iterations exhausted, return True + # always increase the last one. If it is exhausted, increase the next one and so forth + map_exhausted = True + for p, range in zip(map.params[::-1], map.range[::-1]): # reversed order + curr_value = mapping[p] + if not isinstance(range[1], SymExpr): + if curr_value.subs(mapping) + range[2].subs(mapping) <= range[1].subs(mapping): + # update this value and then we are done + mapping[p] = curr_value.subs(mapping) + range[2].subs(mapping) + map_exhausted = False + break + else: + # set current param to start again and continue + mapping[p] = range[0].subs(mapping) + else: + if curr_value.subs(mapping) + range[2].subs(mapping) <= range[1].expr.subs(mapping): + # update this value and we done + mapping[p] = curr_value.subs(mapping) + range[2].subs(mapping) + map_exhausted = False + break + else: + # set current param to start again and continue + mapping[p] = range[0].subs(mapping) + return map_exhausted + + +def map_op_in(state: SDFGState, op_in_map: Dict[str, sp.Expr], entry, mapping, stack, clt, C, symbols, array_names, + decided_branches, ask_user): + # we are inside a map --> we need to iterate over the map range and check each memory access. + for p, range in zip(entry.map.params, entry.map.range): + # map each map iteration variable to its start + mapping[p] = range[0].subs(mapping) + map_misses = 0 + while True: + # do analysis of map contents + map_misses += scope_op_in(state, op_in_map, mapping, stack, clt, C, symbols, array_names, decided_branches, + ask_user, entry) + + if update_map_iterators(entry.map, mapping): + break + return map_misses + + +def scope_op_in(state: SDFGState, + op_in_map: Dict[str, sp.Expr], + mapping, + stack: AccessStack, + clt: CacheLineTracker, + C, + symbols, + array_names, + decided_branches, + ask_user, + entry=None): + """ + Computes the operational intensity of a single scope (scope is either an SDFG state or a map scope). + + :param sdfg: The SDFG to analyze. + :param op_in_map: Dictionary storing the resulting operational intensity for each SDFG element. + :param mapping: Mapping of SDFG symbols to their current values. + :param stack: The stack used to track the stack distances. + :param clt: The current CacheLineTracker object mapping data container accesses to cache line ids. + :param C: Cache size in bytes. + :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. + :param array_names: A dictionary mapping local nested SDFG array names to global array names. + :param decided_branches: Dictionary keeping track of user's decisions on which branches to analyze (if ask_user is True). + :param ask_user: If True, the user has to decide which branch to analyze in case it cannot be determined automatically. If False, + all branches get analyzed. + :param entry: If None, the whole state gets analyzed. 
Else, only the scope starting at this entry node is analyzed. + """ + + # find the number of cache misses for each node. + # for maps and nested SDFG, we do it recursively. + scope_misses = 0 + scope_nodes = state.scope_children()[entry] + for node in scope_nodes: + if isinstance(node, nd.EntryNode): + # If the scope contains an entry node, we need to recursively analyze the sub-scope of the entry node first. + map_misses = map_op_in(state, op_in_map, node, mapping, stack, clt, C, symbols, array_names, + decided_branches, ask_user) + + update_map(op_in_map, get_uuid(node, state), map_misses) + scope_misses += map_misses + elif isinstance(node, nd.Tasklet): + tasklet_misses = 0 + # analyze the memory accesses of this tasklet and whether they hit in cache or not + for e in state.in_edges(node) + state.out_edges(node): + if e.data.data in clt.array_info or (e.data.data in array_names + and array_names[e.data.data] in clt.array_info): + line_id = clt.cache_line_id( + e.data.data if e.data.data not in array_names else array_names[e.data.data], + [x[0].subs(mapping) for x in e.data.subset.ranges], mapping) + + line_id = int(line_id.subs(mapping)) + dist = stack.touch(line_id) + tasklet_misses += 1 if dist >= C or dist == -1 else 0 + + scope_misses += tasklet_misses + # a tasklet can get passed multiple times... we report the average misses in the end + # op_in_map is a tuple for each element consisting of (num_total_misses, accesses). + # num_total_misses / accesses then gives the average misses + update_map(op_in_map, get_uuid(node, state), tasklet_misses) + elif isinstance(node, nd.NestedSDFG): + + # keep track of nested symbols: "symbols" maps local nested SDFG symbols to global symbols. + # We only want global symbols in our final expressions. + nested_syms = {} + nested_syms.update(symbols) + nested_syms.update(evaluate_symbols(symbols, node.symbol_mapping)) + + # Handle nested arrays: Inside the nested SDFG, an array could have a different name, even + # though the same array is referenced + nested_array_names = {} + nested_array_names.update(array_names) + # for each conncector to the nested SDFG, add a pair (connector_name, incoming array name) to the dict + for e in state.in_edges(node): + nested_array_names[e.dst_conn] = e.data.data + for e in state.out_edges(node): + nested_array_names[e.src_conn] = e.data.data + # Nested SDFGs are recursively analyzed first. 
+ nsdfg_misses = sdfg_op_in(node.sdfg, op_in_map, mapping, stack, clt, C, nested_syms, nested_array_names, + decided_branches, ask_user) + + scope_misses += nsdfg_misses + update_map(op_in_map, get_uuid(node, state), nsdfg_misses) + elif isinstance(node, nd.LibraryNode): + # add a symbol to the top level sdfg, such that the user can define it in the extension + top_level_sdfg = state.parent + try: + top_level_sdfg.add_symbol(f'{node.name}_misses', dtypes.int64) + except FileExistsError: + pass + lib_node_misses = sp.Symbol(f'{node.name}_misses', positive=True) + lib_node_misses = lib_node_misses.subs(mapping) + scope_misses += lib_node_misses + update_map(op_in_map, get_uuid(node, state), lib_node_misses) + if entry is None: + # if entry is none this means that we are analyzing the whole state --> save number of misses in get_uuid(state) + update_map(op_in_map, get_uuid(state), scope_misses, average=False) + return scope_misses + + +def sdfg_op_in(sdfg: SDFG, + op_in_map: Dict[str, Tuple[sp.Expr, sp.Expr]], + mapping, + stack: AccessStack, + clt: CacheLineTracker, + C, + symbols, + array_names, + decided_branches, + ask_user, + start=None, + end=None): + """ + Computes the operational intensity of the input SDFG. + + :param sdfg: The SDFG to analyze. + :param op_in_map: Dictionary storing the resulting operational intensity for each SDFG element. + :param mapping: Mapping of SDFG symbols to their current values. + :param stack: The stack used to track the stack distances. + :param clt: The current CacheLineTracker object mapping data container accesses to cache line ids. + :param C: Cache size in bytes. + :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. + :param array_names: A dictionary mapping local nested SDFG array names to global array names. + :param decided_branches: Dictionary keeping track of user's decisions on which branches to analyze (if ask_user is True). + :param ask_user: If True, the user has to decide which branch to analyze in case it cannot be determined automatically. If False, + all branches get analyzed. + :param start: The start state of the SDFG traversal. If None, the SDFG's normal start state is used. + :param end: The end state of the SDFG traversal. If None, the whole SDFG is traversed. + """ + + if start is None: + # add this SDFG's arrays to the cache line tracker + for name, arr in sdfg.arrays.items(): + if isinstance(arr, Array): + if name in array_names: + name = array_names[name] + clt.add_array(name, arr, mapping) + # start traversal at SDFG's start state + curr_state = sdfg.start_state + else: + curr_state = start + + total_misses = 0 + # traverse this SDFG's states + while True: + total_misses += scope_op_in(curr_state, op_in_map, mapping, stack, clt, C, symbols, array_names, + decided_branches, ask_user) + + if len(sdfg.out_edges(curr_state)) == 0: + # we reached an end state --> stop + break + else: + # take first edge with True condition + found = False + for e in sdfg.out_edges(curr_state): + if e.data.is_unconditional() or e.data.condition_sympy().subs(mapping) == True: + # save e's assignments in mapping and update curr_state + # replace values first with mapping, then update mapping + try: + update_mapping(mapping, e) + except: + print('\nWARNING: Uncommon assignment detected on InterstateEdge (e.g. bitwise operators).' 
+ 'Analysis may give wrong results.') + print(e.data.assignments, 'was the edge\'s assignments.') + curr_state = e.dst + found = True + break + if not found: + # We need to check if we are in an implicit end state (i.e. all outgoing edge conditions evaluate to False) + all_false = True + for e in sdfg.out_edges(curr_state): + if e.data.condition_sympy().subs(mapping) != False: + all_false = False + if all_false: + break + + if curr_state in decided_branches: + # if the user already decided this branch in a previous iteration, take the same branch again. + e = decided_branches[curr_state] + + update_mapping(mapping, e) + curr_state = e.dst + else: + # we cannot determine which branch to take --> check if both contain work + merge_state = find_merge_state(sdfg, curr_state) + next_edge_candidates = [] + for e in sdfg.out_edges(curr_state): + states = find_states_between(sdfg, e.dst, merge_state) + curr_work = mem_accesses_on_path(states) + if sp.sympify(curr_work).subs(mapping) > 0: + next_edge_candidates.append(e) + + if len(next_edge_candidates) == 1: + e = next_edge_candidates[0] + update_mapping(mapping, e) + decided_branches[curr_state] = e + curr_state = e.dst + else: + if ask_user: + edges = sdfg.out_edges(curr_state) + print(f'\n\nWhich branch to take at {curr_state.name}') + for i in range(len(edges)): + print(f'({i}) for edge to state {edges[i].dst.name}') + print(edges[i].dst._read_and_write_sets()) + print('merge state is named ', merge_state) + chosen = int(input('Choose an option from above: ')) + e = edges[chosen] + update_mapping(mapping, e) + decided_branches[curr_state] = e + curr_state = e.dst + print(2 * '\n') + else: + final_e = next_edge_candidates.pop() + for e in next_edge_candidates: + + # copy the state of the analysis + curr_mapping = dict(mapping) + update_mapping(curr_mapping, e) + curr_stack = stack.copy() + curr_clt = clt.copy() + curr_symbols = dict(symbols) + curr_array_names = dict(array_names) + + curr_state = e.dst + # walk down this branch until merge_state + sdfg_op_in(sdfg, op_in_map, curr_mapping, curr_stack, curr_clt, C, curr_symbols, + curr_array_names, decided_branches, ask_user, curr_state, merge_state) + + update_mapping(mapping, final_e) + curr_state = final_e.dst + if curr_state == end: + break + + if end is None: + # only update if we were actually analyzing a whole sdfg (not just start to end state) + update_map(op_in_map, get_uuid(sdfg), total_misses, average=False) + return total_misses + + +def analyze_sdfg_op_in(sdfg: SDFG, + op_in_map: Dict[str, sp.Expr], + C, + L, + assumptions, + generate_plots=False, + stringify=False, + test_set_size=3, + ask_user=False): + """ + Computes the operational intensity of the input SDFG. + + :param sdfg: The SDFG to analyze. + :param op_in_map: Dictionary storing the resulting operational intensity for each SDFG element. + :param C: Cache size in bytes. + :param L: Cache line size in bytes. + :param assumptions: Dictionary mapping SDFG symbols to concrete values, e.g. {'N': 8}. At most one symbol might be associated + with a range of (start, stop, step), e.g. {'M' : '2,10,1'}. + :param generate_plots: If True (and there is a range symbol N), a plot showing the operational intensity as a function of N + for the whole SDFG. + :param stringify: If True, the final operational intensity values will be converted to strings. + :param test_set_size: The size of the test set when testing the goodness of fit. 
+ :param ask_user: If True, the user has to decide which branch to analyze in case it cannot be determined automatically. If False, + all branches get analyzed. + """ + + # from now on we take C as the number of lines that fit into cache + C = C // L + + sdfg = deepcopy(sdfg) + # apply SSA pass + pipeline = FixedPointPipeline([StrictSymbolSSA()]) + pipeline.apply_pass(sdfg, {}) + + # check if all symbols are concretized (at most one can be associated with a range) + undefined_symbols = set() + range_symbol = {} + for sym in sdfg.free_symbols: + if sym not in assumptions: + undefined_symbols.add(sym) + elif isinstance(assumptions[sym], str): + range_symbol[sym] = SymbolRange(int(x) for x in assumptions[sym].split(',')) + del assumptions[sym] + + work_map = {} + assumptions_list = [f'{x}=={y}' for x, y in assumptions.items()] + analyze_sdfg(sdfg, work_map, get_tasklet_work, assumptions_list) + + if len(undefined_symbols) > 0: + raise Exception( + f'Undefined symbols detected: {undefined_symbols}. Please specify a value for all free symbols of the SDFG.' + ) + else: + # all symbols defined + if len(range_symbol) > 1: + raise Exception('More than one range symbol detected! Only one range symbol allowed.') + elif len(range_symbol) == 0: + # all symbols are concretized --> run normal op_in analysis with concretized symbols + sdfg.specialize(assumptions) + mapping = {} + mapping.update(assumptions) + + stack = AccessStack(C) + clt = CacheLineTracker(L) + + sdfg_op_in(sdfg, op_in_map, mapping, stack, clt, C, {}, {}, {}, ask_user) + # compute bytes + for k, v in op_in_map.items(): + op_in_map[k] = v[0] / v[1] * L + calculate_op_in(op_in_map, work_map, stringify) + else: + # we have one variable symbol + + # decided_branches: Dict[SDFGState, InterstateEdge] = {} + cache_miss_measurements = {} + work_measurements = [] + t = 0 + while True: + new_val = False + for sym, r in range_symbol.items(): + val = r.next() + if val > -1: + new_val = True + assumptions[sym] = val + elif t < 3: + # now we sample test set + t += 1 + assumptions[sym] = r.max_value() + t * 3 + new_val = True + if not new_val: + break + + curr_op_in_map = {} + mapping = {} + mapping.update(assumptions) + stack = AccessStack(C) + clt = CacheLineTracker(L) + sdfg_op_in(sdfg, curr_op_in_map, mapping, stack, clt, C, {}, {}, {}, ask_user) + + # compute average cache misses + for k, v in curr_op_in_map.items(): + curr_op_in_map[k] = v[0] / v[1] + + # save cache misses + curr_cache_misses = dict(curr_op_in_map) + + work_measurements.append(work_map[get_uuid(sdfg)][0].subs(assumptions)) + # put curr values in cache_miss_measurements + for k, v in curr_cache_misses.items(): + if k in cache_miss_measurements: + cache_miss_measurements[k].append(v) + else: + cache_miss_measurements[k] = [v] + + symbol_name = next(iter(range_symbol.keys())) + x_values = range_symbol[symbol_name].to_list() + x_values.extend([r.max_value() + t * 3 for t in range(1, test_set_size + 1)]) + + sympy_fs = {} + for k, v in cache_miss_measurements.items(): + final_f, sympy_f, r_s = fit_curve(x_values[:-test_set_size], v[:-test_set_size], symbol_name) + op_in_map[k] = sp.simplify(sympy_f * L) + sympy_fs[k] = sympy_f + if k == get_uuid(sdfg): + # compute MAPE on total SDFG + mape = compute_mape(final_f, x_values[-test_set_size:], v[-test_set_size:], test_set_size) + if mape > 0.2: + print('High MAPE detected:', mape) + print('It is suggested to generate plots and analyze those.') + print('R^2 is:', r_s) + print('A hight R^2 (i.e. 
close to 1) suggests that we are fitting the test data well.') + print('This combined with high MAPE tells us that our test data does not generalize.') + calculate_op_in(op_in_map, work_map, not generate_plots) + + if generate_plots: + # plot results for the whole SDFG + plot(x_values, work_map, cache_miss_measurements, op_in_map, symbol_name, C, L, sympy_fs, + get_uuid(sdfg), sdfg.name) + + if stringify: + for k, v in op_in_map.items(): + op_in_map[k] = str(v) + + +################################################################################ +# Utility functions for running the analysis from the command line ############# +################################################################################ + + +def main() -> None: + + parser = argparse.ArgumentParser('operational_intensity', + usage='python operational_intensity.py [-h] filename', + description='Analyze the operational_intensity of an SDFG.') + + parser.add_argument('filename', type=str, help='The SDFG file to analyze.') + parser.add_argument('--C', type=str, help='Cache size in bytes') + parser.add_argument('--L', type=str, help='Cache line size in bytes') + + parser.add_argument('--assume', nargs='*', help='Collect assumptions about symbols, e.g. x>0 x>y y==5') + args = parser.parse_args() + + args = parser.parse_args() + if not os.path.exists(args.filename): + print(args.filename, 'does not exist.') + exit() + + sdfg = SDFG.from_file(args.filename) + op_in_map = {} + if args.assume is None: + args.assume = [] + + assumptions = {} + for x in args.assume: + a, b = x.split('==') + if b.isdigit(): + assumptions[a] = int(b) + else: + assumptions[a] = b + print(assumptions) + analyze_sdfg_op_in(sdfg, op_in_map, int(args.C), int(args.L), assumptions) + + result_whole_sdfg = op_in_map[get_uuid(sdfg)] + + print(80 * '-') + print("Operational Intensity:\t", result_whole_sdfg) + print(80 * '-') + + +if __name__ == '__main__': + main() diff --git a/dace/sdfg/work_depth_analysis/work_depth.py b/dace/sdfg/performance_evaluation/work_depth.py similarity index 91% rename from dace/sdfg/work_depth_analysis/work_depth.py rename to dace/sdfg/performance_evaluation/work_depth.py index 3549e86a20..c1277b1c4e 100644 --- a/dace/sdfg/work_depth_analysis/work_depth.py +++ b/dace/sdfg/performance_evaluation/work_depth.py @@ -7,7 +7,7 @@ from dace.sdfg import nodes as nd, propagation, InterstateEdge from dace import SDFG, SDFGState, dtypes from dace.subsets import Range -from typing import Tuple, Dict +from typing import List, Tuple, Dict import os import sympy as sp from copy import deepcopy @@ -18,8 +18,8 @@ import astunparse import warnings -from dace.sdfg.work_depth_analysis.helpers import get_uuid, find_loop_guards_tails_exits -from dace.sdfg.work_depth_analysis.assumptions import parse_assumptions +from dace.sdfg.performance_evaluation.helpers import LoopExtractionError, get_uuid, find_loop_guards_tails_exits +from dace.sdfg.performance_evaluation.assumptions import parse_assumptions from dace.transformation.passes.symbol_ssa import StrictSymbolSSA from dace.transformation.pass_pipeline import FixedPointPipeline @@ -70,7 +70,8 @@ def count_work_matmul(node, symbols, state): if len(C_memlet.data.subset) == 3: result *= symeval(C_memlet.data.subset.size()[0], symbols) # M*N - result *= symeval(C_memlet.data.subset.size()[-2], symbols) + # we need the if else, since C_memlet is one dimensional in case of matrix vector product + result *= 1 if len(C_memlet.data.subset.size()) < 2 else symeval(C_memlet.data.subset.size()[-2], symbols) result *= 
symeval(C_memlet.data.subset.size()[-1], symbols) # K result *= symeval(A_memlet.data.subset.size()[-1], symbols) @@ -81,7 +82,7 @@ def count_depth_matmul(node, symbols, state): # optimal depth of a matrix multiplication is O(log(size of shared dimension)): A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') size_shared_dimension = symeval(A_memlet.data.subset.size()[-1], symbols) - return bigo(sp.log(size_shared_dimension)) + return sp.log(size_shared_dimension) def count_work_reduce(node, symbols, state): @@ -101,7 +102,7 @@ def count_work_reduce(node, symbols, state): def count_depth_reduce(node, symbols, state): # optimal depth of reduction is log of the work - return bigo(sp.log(count_work_reduce(node, symbols, state))) + return sp.log(count_work_reduce(node, symbols, state)) LIBNODES_TO_WORK = { @@ -116,11 +117,11 @@ def count_depth_reduce(node, symbols, state): Reduce: count_depth_reduce, } -bigo = sp.Function('bigo') PYFUNC_TO_ARITHMETICS = { 'float': 0, 'dace.float64': 0, 'dace.int64': 0, + 'dace.complex128': 0, 'math.exp': 1, 'exp': 1, 'math.tanh': 1, @@ -129,7 +130,7 @@ def count_depth_reduce(node, symbols, state): 'tanh': 1, 'math.sqrt': 1, 'sqrt': 1, - 'atan2:': 1, + 'atan2': 1, 'min': 0, 'max': 0, 'ceiling': 0, @@ -223,7 +224,6 @@ def visit_While(self, node): def count_depth_code(code): - # so far this is the same as the work counter, since work = depth for each tasklet, as we can't assume any parallelism ctr = ArithmeticCounter() if isinstance(code, (tuple, list)): for stmt in code: @@ -287,9 +287,11 @@ def update_value_map(old, new): def do_initial_subs(w, d, eq, subs1): """ - Calls subs three times for the give (w)ork and (d)epth values. + Calls subs three times for the given (w)ork and (d)epth values. """ - return sp.simplify(w.subs(eq[0]).subs(eq[1]).subs(subs1)), sp.simplify(d.subs(eq[0]).subs(eq[1]).subs(subs1)) + result = sp.simplify(sp.sympify(w).subs(eq[0]).subs(eq[1]).subs(subs1)), sp.simplify( + sp.sympify(d).subs(eq[0]).subs(eq[1]).subs(subs1)) + return result def sdfg_work_depth(sdfg: SDFG, @@ -326,10 +328,12 @@ def sdfg_work_depth(sdfg: SDFG, detailed_analysis) # Substitutions for state_work and state_depth already performed, but state.executions needs to be subs'd now. - state_work = sp.simplify(state_work * - state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) - state_depth = sp.simplify(state_depth * - state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + state_work = sp.simplify( + state_work.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) * + state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + state_depth = sp.simplify( + state_depth.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) * + state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) state_works[state], state_depths[state] = state_work, state_depth w_d_map[get_uuid(state)] = (state_works[state], state_depths[state]) @@ -340,7 +344,21 @@ def sdfg_work_depth(sdfg: SDFG, # Additionally, construct a dummy exit state and connect every state that has no outgoing edges to it. # identify all loops in the SDFG - nodes_oNodes_exits = find_loop_guards_tails_exits(sdfg._nx) + try: + nodes_oNodes_exits = find_loop_guards_tails_exits(sdfg._nx) + except LoopExtractionError: + # If loop detection fails, we cannot make proper propagation. + print('Analysis failed since not all loops got detected. It may help to use more structured loop constructs.' 
+ + ' The analysis per state remains correct, but no SDFG-wide analysis can be performed.') + sdfg_result = (sp.oo, sp.oo) + w_d_map[get_uuid(sdfg)] = sdfg_result + + for k, (v_w, v_d) in w_d_map.items(): + # The symeval replaces nested SDFG symbols with their global counterparts. + v_w = symeval(v_w, symbols) + v_d = symeval(v_d, symbols) + w_d_map[k] = (v_w, v_d) + return sdfg_result # Now we need to go over each triple (node, oNode, exits). For each triple, we # - remove edge (oNode, node), i.e. the backward edge @@ -392,8 +410,7 @@ def sdfg_work_depth(sdfg: SDFG, else: state_value_map[state] = value_map - # ignore assignments such as tmp=x[0], as those do not give much information. - value_map = {k: v for k, v in state_value_map[state].items() if '[' not in k and '[' not in v} + value_map = {pystr_to_symbolic(k): pystr_to_symbolic(v) for k, v in state_value_map[state].items()} n_depth = sp.simplify((depth + state_depths[state]).subs(value_map)) n_work = sp.simplify((work + state_works[state]).subs(value_map)) @@ -458,10 +475,19 @@ def sdfg_work_depth(sdfg: SDFG, new_cse_stack.append((work_map[state], depth_map[state])) # same for value_map new_value_map = dict(state_value_map[state]) - new_value_map.update({sp.Symbol(k): sp.Symbol(v) for k, v in oedge.data.assignments.items()}) + new_value_map.update({ + pystr_to_symbolic(k): + pystr_to_symbolic(v).subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) + for k, v in oedge.data.assignments.items() + }) traversal_q.append((oedge.dst, 0, 0, oedge, new_cond_stack, new_cse_stack, new_value_map)) else: - value_map.update(oedge.data.assignments) + # value_map.update(oedge.data.assignments) + value_map.update({ + pystr_to_symbolic(k): + pystr_to_symbolic(v).subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) + for k, v in oedge.data.assignments.items() + }) traversal_q.append((oedge.dst, depth_map[state], work_map[state], oedge, condition_stack, common_subexpr_stack, value_map)) @@ -471,11 +497,17 @@ def sdfg_work_depth(sdfg: SDFG, except KeyError: # If we get a KeyError above, this means that the traversal never reached the dummy_exit state. # This happens if the loops were not properly detected and broken. - raise Exception( + raise LoopExtractionError( 'Analysis failed, since not all loops got detected. It may help to use more structured loop constructs.') sdfg_result = (max_work, max_depth) w_d_map[get_uuid(sdfg)] = sdfg_result + + for k, (v_w, v_d) in w_d_map.items(): + # The symeval replaces nested SDFG symbols with their global counterparts. + v_w = symeval(v_w, symbols) + v_d = symeval(v_d, symbols) + w_d_map[k] = (v_w, v_d) return sdfg_result @@ -531,9 +563,6 @@ def scope_work_depth( # add up work for whole state, but also save work for this sub-scope scope in w_d_map work += s_work w_d_map[get_uuid(node, state)] = (s_work, s_depth) - elif node == scope_exit: - # don't do anything for exit nodes, everthing handled already in the corresponding entry node. - pass elif isinstance(node, nd.Tasklet): # add up work for whole state, but also save work for this node in w_d_map t_work, t_depth = analyze_tasklet(node, state) @@ -567,9 +596,14 @@ def scope_work_depth( # TODO: This symbol should now appear in the VS code extension in the SDFG analysis tab, # such that the user can define its value. But it doesn't... # How to achieve this? 
- top_level_sdfg.add_symbol(f'{node.name}_work', dtypes.int64) + try: + top_level_sdfg.add_symbol(f'{node.name}_work', dtypes.int64) + except FileExistsError: + # Such a library node was already encountered by the analysis. + # Hence, we don't need to add anyting. + pass lib_node_work = sp.Symbol(f'{node.name}_work', positive=True) - lib_node_depth = sp.sympify(-1) # not analyzed + lib_node_depth = sp.sympify(-1) if analyze_tasklet != get_tasklet_work: # we are analyzing depth try: @@ -704,7 +738,7 @@ def state_work_depth(state: SDFGState, def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet, - assumptions: [str], + assumptions: List[str], detailed_analysis: bool = False) -> None: """ Analyze a given SDFG. We can either analyze work, work and depth or average parallelism. @@ -823,7 +857,7 @@ def main() -> None: elif args.analyze == 'work': print("Work:\t", result_whole_sdfg) elif args.analyze == 'avgPar': - print("Average Parallelism:\t", result_whole_sdfg) + print("Average Parallelism:\t", sp.N(result_whole_sdfg)) print(80 * '-') diff --git a/tests/sdfg/operational_intensity_test.py b/tests/sdfg/operational_intensity_test.py new file mode 100644 index 0000000000..4406ecb0b8 --- /dev/null +++ b/tests/sdfg/operational_intensity_test.py @@ -0,0 +1,148 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" Contains test cases for the operational intensity analysis. """ +from typing import Dict, Tuple + +import pytest +import dace as dc +import sympy as sp +import numpy as np +from dace.sdfg.performance_evaluation.operational_intensity import analyze_sdfg_op_in +from dace.sdfg.performance_evaluation.helpers import get_uuid +from dace.frontend.python.parser import DaceProgram + +from math import isclose + +N = dc.symbol('N') +M = dc.symbol('M') +K = dc.symbol('K') + +TILE_SIZE = dc.symbol('TILE_SIZE') + + +@dc.program +def single_map64(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N]): + z[:] = x + y + # does N work, loads 3*N elements of 8 bytes + # --> op_in should be N / 3*8*N = 1/24 (no reuse) assuming L divides N + + +@dc.program +def single_map16(x: dc.float16[N], y: dc.float16[N], z: dc.float16[N]): + z[:] = x + y + # does N work, loads 3*N elements of 2 bytes + # --> op_in should be N / 3*2*N = 1/6 (no reuse) assuming L divides N + + +@dc.program +def single_for_loop(x: dc.float64[N], y: dc.float64[N]): + for i in range(N): + x[i] += y[i] + # N work, 2*N*8 bytes loaded + # --> 1/16 op in + + +@dc.program +def if_else(x: dc.int64[100], sum: dc.int64[1]): + if x[10] > 50: + for i in range(100): + sum += x[i] + if x[0] > 3: + for i in range(100): + sum += x[i] + # no else --> simply analyze the ifs. if cache big enough, everything is reused + + +@dc.program +def unaligned_for_loop(x: dc.float32[100], sum: dc.int64[1]): + for i in range(17, 53): + sum += x[i] + + +@dc.program +def sequential_maps(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N]): + z[:] = x + y + z[:] *= 2 + z[:] += x + # does N work, loads 3*N elements of 8 bytes + # --> op_in should be N / 3*8*N = 1/24 (no reuse) assuming L divides N + + +@dc.program +def nested_reuse(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N], result: dc.float64[1]): + # load x, y and z + z[:] = x + y + result[0] = np.sum(z) + # tests whether the access to z from the nested SDFG correspond with the prior accesses + # to z outside of the nested SDFG. 
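The expected values quoted in the comments above follow from a simple cold-miss model: every element of every array is loaded exactly once, so the operational intensity is the work per element divided by the bytes moved per element (assuming, as the comments note, that the cache line size divides the array size so every loaded line is fully used). A quick hand-check of those numbers, using a small illustrative helper that is not part of the PR:

# Cold-miss arithmetic behind the expected op_in values quoted above.
def expected_op_in(work_per_elem, n_arrays, bytes_per_elem):
    return work_per_elem / (n_arrays * bytes_per_elem)

assert expected_op_in(1, 3, 8) == 1 / 24   # single_map64: z = x + y on float64
assert expected_op_in(1, 3, 2) == 1 / 6    # single_map16: z = x + y on float16
assert expected_op_in(1, 2, 8) == 1 / 16   # single_for_loop: x[i] += y[i] on float64
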
+ + +@dc.program +def mmm(x: dc.float64[N, N], y: dc.float64[N, N], z: dc.float64[N, N]): + for n, k, m in dc.map[0:N, 0:N, 0:N]: + z[n, k] += x[n, m] * y[m, k] + + +@dc.program +def tiled_mmm(x: dc.float64[N, N], y: dc.float64[N, N], z: dc.float64[N, N]): + for n_TILE, k_TILE, m_TILE in dc.map[0:N:TILE_SIZE, 0:N:TILE_SIZE, 0:N:TILE_SIZE]: + for n, k, m in dc.map[n_TILE:n_TILE + TILE_SIZE, k_TILE:k_TILE + TILE_SIZE, m_TILE:m_TILE + TILE_SIZE]: + z[n, k] += x[n, m] * y[m, k] + + +@dc.program +def tiled_mmm_32(x: dc.float32[N, N], y: dc.float32[N, N], z: dc.float32[N, N]): + for n_TILE, k_TILE, m_TILE in dc.map[0:N:TILE_SIZE, 0:N:TILE_SIZE, 0:N:TILE_SIZE]: + for n, k, m in dc.map[n_TILE:n_TILE + TILE_SIZE, k_TILE:k_TILE + TILE_SIZE, m_TILE:m_TILE + TILE_SIZE]: + z[n, k] += x[n, m] * y[m, k] + + +@dc.program +def reduction_library_node(x: dc.float64[N]): + return np.sum(x) + + +#(sdfg, c, l, assumptions, expected_result) +test_cases: Dict[str, Tuple[DaceProgram, int, int, Dict[str, int], dc.symbolic.SymbolicType]] = { + 'single_map64_even': (single_map64, 64 * 64, 64, { 'N': 512 }, 1 / 24), + 'single_map16_even': (single_map16, 64 * 64, 64, { 'N': 512 }, 1 / 6), + # now num_elements_on_single_cache_line does not divie N anymore + # -->513 work, 520 elements loaded --> 513 / (520*8*3) + 'single_map64_uneven': (single_map64, 64 * 64, 64, { 'N': 513 }, 513 / (3 * 8 * 520)), + 'sequential_maps': (sequential_maps, 1024, 3 * 8, { 'N': 29 }, 87 / (90 * 8)), + # smaller cache --> only two arrays fit --> x loaded twice now + 'sequential_maps_small': (sequential_maps, 6, 3 * 8, { 'N': 7 }, 21 / (13 * 3 * 8)), + 'nested_reuse': (nested_reuse, 1024, 64, { 'N': 1024 }, 2048 / (3 * 1024 * 8 + 128)), + 'mmm': (mmm, 20, 16, { 'N': 24 }, (2 * 24**3) / ((36 * 24**2 + 24 * 12) * 16)), + 'tiled_mmm': (tiled_mmm, 20, 16, { 'N': 24, 'TILE_SIZE': 4 }, (2 * 24**3) / (16 * 24 * 6**3)), + 'tiled_mmm_32': (tiled_mmm_32, 10, 16, { 'N': 24, 'TILE_SIZE': 4 }, (2 * 24**3) / (16 * 12 * 6**3)), + 'reduction_library_node': (reduction_library_node, 1024, 64, { 'N': 128 }, + 128.0 / (dc.symbol('Reduce_misses') * 64.0 + 64.0)), +} + + +@pytest.mark.parametrize('test_name', list(test_cases.keys())) +def test_operational_intensity(test_name: str): + test, c, l, assumptions, correct = test_cases[test_name] + op_in_map: Dict[str, sp.Expr] = {} + sdfg = test.to_sdfg() + if test_name == 'nested_reuse': + sdfg.expand_library_nodes() + if test_name in ['sequential_maps', 'sequential_maps_small', 'nested_reuse', 'mmm', 'tiled_mmm', 'tiled_mmm_32']: + sdfg.simplify() + analyze_sdfg_op_in(sdfg, op_in_map, c * l, l, assumptions) + res = (op_in_map[get_uuid(sdfg)]) + if test_name == 'reduction_library_node': + # substitue each symbol without assumptions. + # We do this since sp.Symbol('N') == Sp.Symbol('N', positive=True) --> False. + reps = {s: sp.Symbol(s.name) for s in res.free_symbols} + res = res.subs(reps) + reps = {s: sp.Symbol(s.name) for s in sp.sympify(correct).free_symbols} + correct = sp.sympify(correct).subs(reps) + assert correct == res + else: + assert isclose(correct, res) + + +if __name__ == '__main__': + for test_name in test_cases.keys(): + test_operational_intensity(test_name) diff --git a/tests/sdfg/work_depth_test.py b/tests/sdfg/work_depth_test.py new file mode 100644 index 0000000000..e677cca752 --- /dev/null +++ b/tests/sdfg/work_depth_test.py @@ -0,0 +1,330 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" Contains test cases for the work depth analysis. 
""" +from typing import Dict, List, Tuple + +import pytest +import dace as dc +from dace import symbolic +from dace.frontend.python.parser import DaceProgram +from dace.sdfg.performance_evaluation.work_depth import (analyze_sdfg, get_tasklet_work_depth, get_tasklet_avg_par, + parse_assumptions) +from dace.sdfg.performance_evaluation.helpers import get_uuid +from dace.sdfg.performance_evaluation.assumptions import ContradictingAssumptions +import sympy as sp +import numpy as np + +from dace.transformation.interstate import NestSDFG +from dace.transformation.dataflow import MapExpansion + +from pytest import raises + +N = dc.symbol('N') +M = dc.symbol('M') +K = dc.symbol('K') + + +@dc.program +def single_map(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N]): + z[:] = x + y + + +@dc.program +def single_for_loop(x: dc.float64[N], y: dc.float64[N]): + for i in range(N): + x[i] += y[i] + + +@dc.program +def if_else(x: dc.int64[1000], y: dc.int64[1000], z: dc.int64[1000], sum: dc.int64[1]): + if x[10] > 50: + z[:] = x + y # 1000 work, 1 depth + else: + for i in range(100): # 100 work, 100 depth + sum += x[i] + + +@dc.program +def if_else_sym(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], sum: dc.int64[1]): + if x[10] > 50: + z[:] = x + y # N work, 1 depth + else: + for i in range(K): # K work, K depth + sum += x[i] + + +@dc.program +def nested_sdfg(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N]): + single_map(x, y, z) + single_for_loop(x, y) + + +@dc.program +def nested_maps(x: dc.float64[N, M], y: dc.float64[N, M], z: dc.float64[N, M]): + z[:, :] = x + y + + +@dc.program +def nested_for_loops(x: dc.float64[N], y: dc.float64[K]): + for i in range(N): + for j in range(K): + x[i] += y[j] + + +@dc.program +def nested_if_else(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], sum: dc.int64[1]): + if x[10] > 50: + if x[9] > 40: + z[:] = x + y # N work, 1 depth + z[:] += 2 * x # 2*N work, 2 depth --> total outer if: 3*N work, 3 depth + else: + if y[9] > 30: + for i in range(K): + sum += x[i] # K work, K depth + else: + for j in range(M): + sum += x[j] # M work, M depth + z[:] = x + y # N work, depth 1 --> total inner else: M+N work, M+1 depth + # --> total outer else: Max(K, M+N) work, Max(K, M+1) depth + # --> total over both branches: Max(K, M+N, 3*N) work, Max(K, M+1, 3) depth + + +@dc.program +def max_of_positive_symbol(x: dc.float64[N]): + if x[0] > 0: + for i in range(2 * N): # work 2*N^2, depth 2*N + x += 1 + else: + for j in range(3 * N): # work 3*N^2, depth 3*N + x += 1 + # total is work 3*N^2, depth 3*N without any max + + +@dc.program +def multiple_array_sizes(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], x2: dc.int64[M], y2: dc.int64[M], + z2: dc.int64[M], x3: dc.int64[K], y3: dc.int64[K], z3: dc.int64[K]): + if x[0] > 0: + z[:] = 2 * x + y # work 2*N, depth 2 + elif x[1] > 0: + z2[:] = 2 * x2 + y2 # work 2*M + 3, depth 5 + z2[0] += 3 + z[1] + z[2] + elif x[2] > 0: + z3[:] = 2 * x3 + y3 # work 2*K, depth 2 + elif x[3] > 0: + z[:] = 3 * x + y + 1 # work 3*N, depth 3 + # --> work= Max(3*N, 2*M, 2*K) and depth = 5 + + +@dc.program +def unbounded_while_do(x: dc.float64[N]): + while x[0] < 100: + x += 1 + + +@dc.program +def unbounded_nonnegify(x: dc.float64[N]): + while x[0] < 100: + if x[1] < 42: + x += 3 * x + else: + x += x + + +@dc.program +def break_for_loop(x: dc.float64[N]): + for i in range(N): + if x[i] > 100: + break + x += 1 + + +@dc.program +def break_while_loop(x: dc.float64[N]): + while x[0] > 10: + if x[1] > 100: + break + x += 1 + + +@dc.program +def sequntial_ifs(x: 
dc.float64[N + 1], y: dc.float64[M + 1]): # --> cannot assume N, M to be positive + if x[0] > 5: + x[:] += 1 # N+1 work, 1 depth + else: + for i in range(M): # M work, M depth + y[i + 1] += y[i] + if M > N: + y[:N + 1] += x[:] # N+1 work, 1 depth + else: + x[:M + 1] += y[:] # M+1 work, 1 depth + # --> Work: Max(N+1, M) + Max(N+1, M+1) + # Depth: Max(1, M) + 1 + + +@dc.program +def reduction_library_node(x: dc.float64[456]): + return np.sum(x) + + +@dc.program +def reduction_library_node_symbolic(x: dc.float64[N]): + return np.sum(x) + + +@dc.program +def gemm_library_node(x: dc.float64[456, 200], y: dc.float64[200, 111], z: dc.float64[456, 111]): + z[:] = x @ y + + +@dc.program +def gemm_library_node_symbolic(x: dc.float64[M, K], y: dc.float64[K, N], z: dc.float64[M, N]): + z[:] = x @ y + + +#(sdfg, (expected_work, expected_depth)) +work_depth_test_cases: Dict[str, Tuple[DaceProgram, Tuple[symbolic.SymbolicType, symbolic.SymbolicType]]] = { + 'single_map': (single_map, (N, 1)), + 'single_for_loop': (single_for_loop, (N, N)), + 'if_else': (if_else, (1000, 100)), + 'if_else_sym': (if_else_sym, (sp.Max(K, N), sp.Max(1, K))), + 'nested_sdfg': (nested_sdfg, (2 * N, N + 1)), + 'nested_maps': (nested_maps, (M * N, 1)), + 'nested_for_loops': (nested_for_loops, (K * N, K * N)), + 'nested_if_else': (nested_if_else, (sp.Max(K, 3 * N, M + N), sp.Max(3, K, M + 1))), + 'max_of_positive_symbols': (max_of_positive_symbol, (3 * N**2, 3 * N)), + 'multiple_array_sizes': (multiple_array_sizes, (sp.Max(2 * K, 3 * N, 2 * M + 3), 5)), + 'unbounded_while_do': (unbounded_while_do, (sp.Symbol('num_execs_0_2') * N, sp.Symbol('num_execs_0_2'))), + # We get this Max(1, num_execs), since it is a do-while loop, but the num_execs symbol does not capture this. + 'unbounded_nonnegify': (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_7') * N, 2 * sp.Symbol('num_execs_0_7'))), + 'break_for_loop': (break_for_loop, (N**2, N)), + 'break_while_loop': (break_while_loop, (sp.Symbol('num_execs_0_5') * N, sp.Symbol('num_execs_0_5'))), + 'sequential_ifs': (sequntial_ifs, (sp.Max(N + 1, M) + sp.Max(N + 1, M + 1), sp.Max(1, M) + 1)), + 'reduction_library_node': (reduction_library_node, (456, sp.log(456))), + 'reduction_library_node_symbolic': (reduction_library_node_symbolic, (N, sp.log(N))), + 'gemm_library_node': (gemm_library_node, (2 * 456 * 200 * 111, sp.log(200))), + 'gemm_library_node_symbolic': (gemm_library_node_symbolic, (2 * M * K * N, sp.log(K))) +} + + +@pytest.mark.parametrize('test_name', list(work_depth_test_cases.keys())) +def test_work_depth(test_name): + if (dc.Config.get_bool('optimizer', 'automatic_simplification') == False and + test_name in ['unbounded_while_do', 'unbounded_nonnegify', 'break_while_loop']): + pytest.skip('Malformed loop when not simplifying') + test, correct = work_depth_test_cases[test_name] + w_d_map: Dict[str, sp.Expr] = {} + sdfg = test.to_sdfg() + if 'nested_sdfg' in test.name: + sdfg.apply_transformations(NestSDFG) + if 'nested_maps' in test.name: + sdfg.apply_transformations(MapExpansion) + analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth, [], False) + res = w_d_map[get_uuid(sdfg)] + # substitue each symbol without assumptions. + # We do this since sp.Symbol('N') == Sp.Symbol('N', positive=True) --> False. 
+    reps = {s: sp.Symbol(s.name) for s in (res[0].free_symbols | res[1].free_symbols)}
+    res = (res[0].subs(reps), res[1].subs(reps))
+    reps = {
+        s: sp.Symbol(s.name)
+        for s in (sp.sympify(correct[0]).free_symbols | sp.sympify(correct[1]).free_symbols)
+    }
+    correct = (sp.sympify(correct[0]).subs(reps), sp.sympify(correct[1]).subs(reps))
+    # check result
+    assert correct == res
+
+
+#(sdfg, expected_avg_par)
+tests_cases_avg_par = {
+    'single_map': (single_map, N),
+    'single_for_loop': (single_for_loop, 1),
+    'if_else': (if_else, 1),
+    'nested_sdfg': (nested_sdfg, 2 * N / (N + 1)),
+    'nested_maps': (nested_maps, N * M),
+    'nested_for_loops': (nested_for_loops, 1),
+    'max_of_positive_symbol': (max_of_positive_symbol, N),
+    'unbounded_while_do': (unbounded_while_do, N),
+    'unbounded_nonnegify': (unbounded_nonnegify, N),
+    'break_for_loop': (break_for_loop, N),
+    'break_while_loop': (break_while_loop, N),
+    'reduction_library_node': (reduction_library_node, 456 / sp.log(456)),
+    'reduction_library_node_symbolic': (reduction_library_node_symbolic, N / sp.log(N)),
+    'gemm_library_node': (gemm_library_node, 2 * 456 * 200 * 111 / sp.log(200)),
+    'gemm_library_node_symbolic': (gemm_library_node_symbolic, 2 * M * K * N / sp.log(K)),
+}
+
+@pytest.mark.parametrize('test_name', list(tests_cases_avg_par.keys()))
+def test_avg_par(test_name: str):
+    if (dc.Config.get_bool('optimizer', 'automatic_simplification') == False and
+            test_name in ['unbounded_while_do', 'unbounded_nonnegify', 'break_while_loop']):
+        pytest.skip('Malformed loop when not simplifying')
+
+    test, correct = tests_cases_avg_par[test_name]
+    w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]] = {}
+    sdfg = test.to_sdfg()
+    if 'nested_sdfg' in test_name:
+        sdfg.apply_transformations(NestSDFG)
+    if 'nested_maps' in test_name:
+        sdfg.apply_transformations(MapExpansion)
+    analyze_sdfg(sdfg, w_d_map, get_tasklet_avg_par, [], False)
+    res = w_d_map[get_uuid(sdfg)][0] / w_d_map[get_uuid(sdfg)][1]
+    # substitute each symbol without assumptions.
+    # We do this since sp.Symbol('N') == sp.Symbol('N', positive=True) --> False.
+    reps = {s: sp.Symbol(s.name) for s in res.free_symbols}
+    res = res.subs(reps)
+    reps = {s: sp.Symbol(s.name) for s in sp.sympify(correct).free_symbols}
+    correct = sp.sympify(correct).subs(reps)
+    # check result
+    assert correct == res
+
+
+x, y, z, a = sp.symbols('x y z a')
+
+# (expr, assumptions, result)
+assumptions_tests = [
+    (sp.Max(x, y), ['x>y'], x), (sp.Max(x, y, z), ['x>y'], sp.Max(x, z)), (sp.Max(x, y), ['x==y'], y),
+    (sp.Max(x, 11) + sp.Max(x, 3), ['x<11'], 11 + sp.Max(x, 3)),
+    (sp.Max(x, 11) + sp.Max(x, 3), ['x<11', 'x>3'], 11 + x),
+    (sp.Max(x, 11), ['x>5', 'x>3', 'x>11'], x), (sp.Max(x, 11), ['x==y', 'x>11'], y),
+    (sp.Max(x, 11) + sp.Max(a, 5), ['a==b', 'b==c', 'c==x', 'a<11', 'c>7'], x + 11),
+    (sp.Max(x, 11) + sp.Max(a, 5), ['a==b', 'b==c', 'c==x', 'b==7'], 18), (sp.Max(x, y), ['y>x', 'y==1000'], 1000),
+    (sp.Max(x, y), ['y0', 'N<5', 'M>5'], M)
+]
+
+# These assumptions should trigger the ContradictingAssumptions exception.
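+# (For instance, in the third case below the equality chains imply b == d == 100 and z == x == 5,
+#  which contradicts the assumption z > b.)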
+tests_for_exception = [['x>10', 'x<9'], ['x==y', 'x>10', 'y<9'],
+                       ['a==b', 'b==c', 'c==d', 'd==e', 'e==f', 'x==y', 'y==z', 'z>b', 'x==5', 'd==100'],
+                       ['x==5', 'x<4']]
+
+
+@pytest.mark.parametrize('expr,assums,res', assumptions_tests)
+def test_assumption_system(expr: sp.Expr, assums: List[str], res: sp.Expr):
+    equality_subs, all_subs = parse_assumptions(assums, set())
+    expr = expr.subs(equality_subs[0])
+    expr = expr.subs(equality_subs[1])
+    for subs1, subs2 in all_subs:
+        expr = expr.subs(subs1)
+        expr = expr.subs(subs2)
+    assert expr == res
+
+
+@pytest.mark.parametrize('assumptions', tests_for_exception)
+def test_assumption_system_contradictions(assumptions):
+    # check that the Exception gets raised.
+    with raises(ContradictingAssumptions):
+        parse_assumptions(assumptions, set())
+
+
+if __name__ == '__main__':
+    for test_name in work_depth_test_cases.keys():
+        test_work_depth(test_name)
+
+    for test_name in tests_cases_avg_par.keys():
+        test_avg_par(test_name)
+
+    for expr, assums, res in assumptions_tests:
+        test_assumption_system(expr, assums, res)
+
+    for assumptions in tests_for_exception:
+        test_assumption_system_contradictions(assumptions)
diff --git a/tests/sdfg/work_depth_tests.py b/tests/sdfg/work_depth_tests.py
deleted file mode 100644
index 05375007df..0000000000
--- a/tests/sdfg/work_depth_tests.py
+++ /dev/null
@@ -1,262 +0,0 @@
-# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
-""" Contains test cases for the work depth analysis. """
-import dace as dc
-from dace.sdfg.work_depth_analysis.work_depth import analyze_sdfg, get_tasklet_work_depth, parse_assumptions
-from dace.sdfg.work_depth_analysis.helpers import get_uuid
-from dace.sdfg.work_depth_analysis.assumptions import ContradictingAssumptions
-import sympy as sp
-
-from dace.transformation.interstate import NestSDFG
-from dace.transformation.dataflow import MapExpansion
-
-from pytest import raises
-
-# TODO: add tests for library nodes (e.g.
reduce, matMul) -# TODO: add tests for average parallelism - -N = dc.symbol('N') -M = dc.symbol('M') -K = dc.symbol('K') - - -@dc.program -def single_map(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N]): - z[:] = x + y - - -@dc.program -def single_for_loop(x: dc.float64[N], y: dc.float64[N]): - for i in range(N): - x[i] += y[i] - - -@dc.program -def if_else(x: dc.int64[1000], y: dc.int64[1000], z: dc.int64[1000], sum: dc.int64[1]): - if x[10] > 50: - z[:] = x + y # 1000 work, 1 depth - else: - for i in range(100): # 100 work, 100 depth - sum += x[i] - - -@dc.program -def if_else_sym(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], sum: dc.int64[1]): - if x[10] > 50: - z[:] = x + y # N work, 1 depth - else: - for i in range(K): # K work, K depth - sum += x[i] - - -@dc.program -def nested_sdfg(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N]): - single_map(x, y, z) - single_for_loop(x, y) - - -@dc.program -def nested_maps(x: dc.float64[N, M], y: dc.float64[N, M], z: dc.float64[N, M]): - z[:, :] = x + y - - -@dc.program -def nested_for_loops(x: dc.float64[N], y: dc.float64[K]): - for i in range(N): - for j in range(K): - x[i] += y[j] - - -@dc.program -def nested_if_else(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], sum: dc.int64[1]): - if x[10] > 50: - if x[9] > 40: - z[:] = x + y # N work, 1 depth - z[:] += 2 * x # 2*N work, 2 depth --> total outer if: 3*N work, 3 depth - else: - if y[9] > 30: - for i in range(K): - sum += x[i] # K work, K depth - else: - for j in range(M): - sum += x[j] # M work, M depth - z[:] = x + y # N work, depth 1 --> total inner else: M+N work, M+1 depth - # --> total outer else: Max(K, M+N) work, Max(K, M+1) depth - # --> total over both branches: Max(K, M+N, 3*N) work, Max(K, M+1, 3) depth - - -@dc.program -def max_of_positive_symbol(x: dc.float64[N]): - if x[0] > 0: - for i in range(2 * N): # work 2*N^2, depth 2*N - x += 1 - else: - for j in range(3 * N): # work 3*N^2, depth 3*N - x += 1 - # total is work 3*N^2, depth 3*N without any max - - -@dc.program -def multiple_array_sizes(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], x2: dc.int64[M], y2: dc.int64[M], - z2: dc.int64[M], x3: dc.int64[K], y3: dc.int64[K], z3: dc.int64[K]): - if x[0] > 0: - z[:] = 2 * x + y # work 2*N, depth 2 - elif x[1] > 0: - z2[:] = 2 * x2 + y2 # work 2*M + 3, depth 5 - z2[0] += 3 + z[1] + z[2] - elif x[2] > 0: - z3[:] = 2 * x3 + y3 # work 2*K, depth 2 - elif x[3] > 0: - z[:] = 3 * x + y + 1 # work 3*N, depth 3 - # --> work= Max(3*N, 2*M, 2*K) and depth = 5 - - -@dc.program -def unbounded_while_do(x: dc.float64[N]): - while x[0] < 100: - x += 1 - - -@dc.program -def unbounded_do_while(x: dc.float64[N]): - while True: - x += 1 - if x[0] >= 100: - break - - -@dc.program -def unbounded_nonnegify(x: dc.float64[N]): - while x[0] < 100: - if x[1] < 42: - x += 3 * x - else: - x += x - - -@dc.program -def continue_for_loop(x: dc.float64[N]): - for i in range(N): - if x[i] > 100: - continue - x += 1 - - -@dc.program -def break_for_loop(x: dc.float64[N]): - for i in range(N): - if x[i] > 100: - break - x += 1 - - -@dc.program -def break_while_loop(x: dc.float64[N]): - while x[0] > 10: - if x[1] > 100: - break - x += 1 - - -@dc.program -def sequntial_ifs(x: dc.float64[N + 1], y: dc.float64[M + 1]): # --> cannot assume N, M to be positive - if x[0] > 5: - x[:] += 1 # N+1 work, 1 depth - else: - for i in range(M): # M work, M depth - y[i + 1] += y[i] - if M > N: - y[:N + 1] += x[:] # N+1 work, 1 depth - else: - x[:M + 1] += y[:] # M+1 work, 1 depth - # --> Work: Max(N+1, M) + Max(N+1, M+1) 
- # Depth: Max(1, M) + 1 - - -#(sdfg, (expected_work, expected_depth)) -tests_cases = [ - (single_map, (N, 1)), - (single_for_loop, (N, N)), - (if_else, (1000, 100)), - (if_else_sym, (sp.Max(K, N), sp.Max(1, K))), - (nested_sdfg, (2 * N, N + 1)), - (nested_maps, (M * N, 1)), - (nested_for_loops, (K * N, K * N)), - (nested_if_else, (sp.Max(K, 3 * N, M + N), sp.Max(3, K, M + 1))), - (max_of_positive_symbol, (3 * N**2, 3 * N)), - (multiple_array_sizes, (sp.Max(2 * K, 3 * N, 2 * M + 3), 5)), - (unbounded_while_do, (sp.Symbol('num_execs_0_2') * N, sp.Symbol('num_execs_0_2'))), - # We get this Max(1, num_execs), since it is a do-while loop, but the num_execs symbol does not capture this. - (unbounded_do_while, (sp.Max(1, sp.Symbol('num_execs_0_1')) * N, sp.Max(1, sp.Symbol('num_execs_0_1')))), - (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_7') * N, 2 * sp.Symbol('num_execs_0_7'))), - (continue_for_loop, (sp.Symbol('num_execs_0_6') * N, sp.Symbol('num_execs_0_6'))), - (break_for_loop, (N**2, N)), - (break_while_loop, (sp.Symbol('num_execs_0_5') * N, sp.Symbol('num_execs_0_5'))), - (sequntial_ifs, (sp.Max(N + 1, M) + sp.Max(N + 1, M + 1), sp.Max(1, M) + 1)) -] - - -def test_work_depth(): - for test, correct in tests_cases: - w_d_map = {} - sdfg = test.to_sdfg() - if 'nested_sdfg' in test.name: - sdfg.apply_transformations(NestSDFG) - if 'nested_maps' in test.name: - sdfg.apply_transformations(MapExpansion) - analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth, [], False) - res = w_d_map[get_uuid(sdfg)] - # substitue each symbol without assumptions. - # We do this since sp.Symbol('N') == Sp.Symbol('N', positive=True) --> False. - reps = {s: sp.Symbol(s.name) for s in (res[0].free_symbols | res[1].free_symbols)} - res = (res[0].subs(reps), res[1].subs(reps)) - reps = { - s: sp.Symbol(s.name) - for s in (sp.sympify(correct[0]).free_symbols | sp.sympify(correct[1]).free_symbols) - } - correct = (sp.sympify(correct[0]).subs(reps), sp.sympify(correct[1]).subs(reps)) - # check result - assert correct == res - - -x, y, z, a = sp.symbols('x y z a') - -# (expr, assumptions, result) -assumptions_tests = [ - (sp.Max(x, y), ['x>y'], x), (sp.Max(x, y, z), ['x>y'], sp.Max(x, z)), (sp.Max(x, y), ['x==y'], y), - (sp.Max(x, 11) + sp.Max(x, 3), ['x<11'], 11 + sp.Max(x, 3)), (sp.Max(x, 11) + sp.Max(x, 3), ['x<11', - 'x>3'], 11 + x), - (sp.Max(x, 11), ['x>5', 'x>3', 'x>11'], x), (sp.Max(x, 11), ['x==y', 'x>11'], y), - (sp.Max(x, 11) + sp.Max(a, 5), ['a==b', 'b==c', 'c==x', 'a<11', 'c>7'], x + 11), - (sp.Max(x, 11) + sp.Max(a, 5), ['a==b', 'b==c', 'c==x', 'b==7'], 18), (sp.Max(x, y), ['y>x', 'y==1000'], 1000), - (sp.Max(x, y), ['y0', 'N<5', 'M>5'], M) -] - -# These assumptions should trigger the ContradictingAssumptions exception. -tests_for_exception = [['x>10', 'x<9'], ['x==y', 'x>10', 'y<9'], - ['a==b', 'b==c', 'c==d', 'd==e', 'e==f', 'x==y', 'y==z', 'z>b', 'x==5', 'd==100'], - ['x==5', 'x<4']] - - -def test_assumption_system(): - for expr, assums, res in assumptions_tests: - equality_subs, all_subs = parse_assumptions(assums, set()) - initial_expr = expr - expr = expr.subs(equality_subs[0]) - expr = expr.subs(equality_subs[1]) - for subs1, subs2 in all_subs: - expr = expr.subs(subs1) - expr = expr.subs(subs2) - assert expr == res - - for assums in tests_for_exception: - # check that the Exception gets raised. - with raises(ContradictingAssumptions): - parse_assumptions(assums, set()) - - -if __name__ == '__main__': - test_work_depth() - test_assumption_system()