This repository has been archived by the owner on Jun 19, 2024. It is now read-only.
forked from pytorch/benchmark
-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_e2e.py
53 lines (48 loc) · 2 KB
/
run_e2e.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import time
import torch
import argparse
import json
from dataclasses import asdict
from torchbenchmark.e2e import E2EBenchmarkResult, load_e2e_model_by_name
from typing import Dict
# Device names, presumably the ones this end-to-end runner is meant to support.
# NOTE(review): not referenced by any code visible in this script — confirm
# whether it is consumed elsewhere or can be removed.
SUPPORT_DEVICE_LIST = ["cpu", "cuda"]
def run(func) -> Dict[str, float]:
    """Time a single call of ``func`` and return its latency.

    When CUDA is available the device is synchronized both before the clock
    starts and after ``func`` returns, so pending CUDA work is not left
    outside (or leaked into) the measured interval.

    Args:
        func: Zero-argument callable to time. Its return value is discarded.

    Returns:
        A dict with a single key, ``"latency_ms"``: elapsed time in
        milliseconds as a float.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()
    # perf_counter_ns() is a monotonic, highest-resolution clock intended for
    # measuring durations. time_ns()/time() read the system clock, which can
    # be adjusted (e.g. by NTP) mid-measurement and skew or even negate the
    # result; the integer-ns variant also avoids float precision loss.
    start_ns = time.perf_counter_ns()
    func()
    if use_cuda:
        torch.cuda.synchronize()
    end_ns = time.perf_counter_ns()
    return {"latency_ms": (end_ns - start_ns) / 1_000_000.0}
def gen_result(m, run_result):
    """Package a model's metadata and a timing measurement into a result.

    Args:
        m: The benchmarked model instance. Its ``device``, ``device_num``,
            ``test``, ``num_examples`` and ``batch_size`` attributes are
            copied into the result. ``num_epochs`` falls back to 1 when the
            model lacks it, and ``accuracy`` is reported only if present.
        run_result: Mapping holding a ``"latency_ms"`` entry, as produced by
            ``run``.

    Returns:
        An ``E2EBenchmarkResult`` whose ``result`` dict contains ``latency``
        (seconds), ``qps`` and, when available, ``accuracy``.
    """
    epochs = getattr(m, "num_epochs", 1)
    bench = E2EBenchmarkResult(
        device=m.device,
        device_num=m.device_num,
        test=m.test,
        num_examples=m.num_examples,
        num_epochs=epochs,
        batch_size=m.batch_size,
        result={},
    )
    metrics = bench.result
    metrics["latency"] = run_result["latency_ms"] / 1000.0
    # Throughput across all epochs: examples processed per second of latency.
    metrics["qps"] = bench.num_examples / metrics["latency"] * bench.num_epochs
    if hasattr(m, "accuracy"):
        metrics["accuracy"] = m.accuracy
    return bench
if __name__ == "__main__":
    import sys

    # The first positional parameter of ArgumentParser is `prog`, so the
    # module docstring must be passed by keyword to serve as the description.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("model", help="Full name of the end-to-end model.")
    parser.add_argument("-t", "--test", choices=["eval", "train"], default="eval", help="Which test to run.")
    parser.add_argument("--bs", type=int, help="Specify batch size.")
    # Arguments this parser does not recognize are forwarded to the model.
    args, extra_args = parser.parse_known_args()
    Model = load_e2e_model_by_name(args.model)
    if not Model:
        # Report on stderr so stdout carries only the JSON result, and use
        # sys.exit: the bare `exit` builtin is injected by the `site` module
        # and is not guaranteed to exist (e.g. under `python -S`).
        print(f"Unable to find model matching {args.model}.", file=sys.stderr)
        sys.exit(-1)
    m = Model(test=args.test, batch_size=args.bs, extra_args=extra_args)
    # The model is expected to expose a zero-argument callable named after
    # the selected test ("eval" or "train"); that call is what gets timed.
    test = getattr(m, args.test)
    result = gen_result(m, run(test))
    print(json.dumps(asdict(result)))