Skip to content

Commit

Permalink
saving rewrites [run_process_replay] (tinygrad#6501)
Browse files Browse the repository at this point in the history
* save rewrites with TRACK_MATCH_STATS=2 [run_process_replay]

* cleaner
  • Loading branch information
geohot authored Sep 13, 2024
1 parent 7c07819 commit 774bf39
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions tinygrad/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,7 @@ def rewrite(self, uop:UOp) -> Optional[UOp]:
# *** tracking pattern matcher ***

TRACK_MATCH_STATS = getenv("TRACK_MATCH_STATS", 0)
contexts: List[Tuple[UOp, List[Tuple[UOp, UOp]]]] = []
match_stats:Dict[UPat, List[Union[int, float]]] = dict()
class TrackedPattenMatcher(PatternMatcher):
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
Expand All @@ -740,23 +741,27 @@ def rewrite(self, uop:UOp) -> Optional[UOp]:
match_stats[p][0] += 1
match_stats[p][2] += (et:=time.perf_counter()-st)
match_stats[p][3] += et
if TRACK_MATCH_STATS >= 2: print(f"{et*1e6:7.2f} us -- ", p.printable())
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
if TRACK_MATCH_STATS >= 2: contexts[-1][1].append((uop, ret))
return ret # NOTE: if it returns None, we keep trying to match
match_stats[p][2] += time.perf_counter()-st
return None

if TRACK_MATCH_STATS:
PatternMatcher = TrackedPattenMatcher # type: ignore
import atexit
import atexit, pickle
@atexit.register
def print_match_stats():
ret = [0,0,0.0,0.0]
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]):
loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}"
if getenv("UPAT_FILE", loc_str) not in loc_str: continue
print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {v[2]*1000.:9.2f} ms -- {loc_str:15s}", k.printable())
if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {v[2]*1000.:9.2f} ms -- {loc_str:15s}", k.printable())
ret = [x+y for x,y in zip(ret, v)]
print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {ret[2]*1000.:9.2f} ms -- TOTAL")
if TRACK_MATCH_STATS >= 2:
with open("/tmp/rewrites.pkl", "wb") as f:
print(f"rewrote {len(contexts)} graphs and applied {sum(len(x[1]) for x in contexts)} rules, saved to /tmp/rewrites.pkl")
pickle.dump(contexts, f)

# *** simple graph rewrite engine ***

Expand All @@ -773,4 +778,6 @@ def rewrite(self, n:UOp) -> UOp:
x = UOp(*replace_source) if new_src != n.src else n
self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x)) else x
return found
def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp: return RewriteContext(pm).rewrite(sink)
def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
if TRACK_MATCH_STATS >= 2: contexts.append((sink, []))
return RewriteContext(pm).rewrite(sink)

0 comments on commit 774bf39

Please sign in to comment.