profiling.py — 70 lines (54 loc) · 1.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from dataclasses import dataclass
from typing import List, Optional, Tuple

import jax

import flax.linen.module as module_lib
from flax.linen import Conv
from flax.linen.summary import (
    _get_call_flops,
    _bytes_repr
)
@dataclass
class ModuleCall:
    """Record of one module-method invocation captured while tracing a Flax model."""
    # Hierarchical path of the module inside the model tree, e.g. ("encoder", "dense_0").
    path: Tuple[str, ...]
    # Name of the invoked method, e.g. "__call__".
    method: str
    # Estimated forward-pass FLOPs for this call (as reported by flax's _get_call_flops).
    flops: float
    # Estimated backward-pass (VJP) FLOPs for this call.
    vjp_flops: float
# Theoretical peak throughput (FLOP/s) per device kind and precision.
# Keys match jax.Device.device_kind strings; extend as new hardware is profiled.
DEVICE_PEAK_FLOPS = {
    "NVIDIA H100 80GB HBM3": {
        "fp32": 5.1e13,     # ~51 TFLOP/s (non-TensorCore fp32)
        "fp16": 1.513e15,   # ~1513 TFLOP/s (TensorCore fp16, with sparsity)
    }
}


def get_peak_flops(device_kind: Optional[str] = None, dtype: str = "fp32") -> float:
    """Return the theoretical peak FLOP/s for a device at a given precision.

    Args:
        device_kind: Device identifier as reported by ``jax.Device.device_kind``.
            Defaults to the first local JAX device.
        dtype: Precision key to look up (e.g. "fp32" or "fp16"). Defaults to
            "fp32", matching the previous hard-coded behavior.

    Returns:
        Peak FLOP/s for the requested device/precision.

    Raises:
        KeyError: If the device kind or precision has no entry in
            ``DEVICE_PEAK_FLOPS`` (with a message listing known devices,
            instead of the previous bare key error).
    """
    if device_kind is None:
        device_kind = jax.devices()[0].device_kind
    try:
        return DEVICE_PEAK_FLOPS[device_kind][dtype]
    except KeyError as err:
        raise KeyError(
            f"No peak-FLOPs entry for device {device_kind!r} with dtype {dtype!r}; "
            f"known devices: {sorted(DEVICE_PEAK_FLOPS)}"
        ) from err
def trace_module_calls(module: module_lib.Module, *args, **kwargs) -> List[ModuleCall]:
    """Abstractly trace ``module.init`` and report FLOPs for every module call.

    The init is run under ``jax.eval_shape`` inside Flax's tabulate context, so
    only shapes are traced and no real parameters are materialized. Flax records
    each module-method invocation; those records are converted into
    ``ModuleCall`` entries ordered by call index.
    """
    with module_lib._tabulate_context():
        # Abstract evaluation only — nothing is actually computed.
        jax.eval_shape(lambda: module.init(*args, **kwargs))
        # NOTE: the call-info stack is popped when the context exits, so the
        # recorded calls must be read while still inside the `with` block.
        recorded = module_lib._context.call_info_stack[-1].calls
        recorded.sort(key=lambda call: call.index)
        traced: List[ModuleCall] = []
        for call in recorded:
            fwd_flops, bwd_flops = _get_call_flops(call, True, True)
            traced.append(ModuleCall(call.path, call.method, fwd_flops, bwd_flops))
        return traced
def memory_usage_params(model_params):
    """Compute total parameter memory and parameter count for a params pytree.

    Args:
        model_params: A pytree (e.g. a Flax params dict) whose leaves are
            arrays exposing ``.size`` and ``.dtype.itemsize``.

    Returns:
        Tuple of (human-readable size string from flax's ``_bytes_repr``,
        total number of parameters).
    """
    total_bytes, total_params = 0, 0
    # jax.tree_leaves is deprecated and removed in recent JAX releases;
    # jax.tree_util.tree_leaves is the supported spelling.
    for param in jax.tree_util.tree_leaves(model_params):
        total_bytes += param.size * param.dtype.itemsize
        total_params += param.size
    # Return the formatted string directly rather than clobbering total_bytes.
    return _bytes_repr(total_bytes), total_params