Simulator accuracy #34

Open · wants to merge 260 commits into base: main

Commits (260)
a095665
Parametrize pytest test_owt
siddharth-krishna Apr 27, 2021
3279ca1
Revert "Fix bug: constructing variadic ops when specifying output_val…
siddharth-krishna Apr 27, 2021
c84c61d
Fix Op constructor: better handling of variadic ops/pre-created outputs
siddharth-krishna Apr 27, 2021
bb604f6
Add support for Relu
siddharth-krishna Apr 28, 2021
8d8bbda
Add DP test
siddharth-krishna Apr 29, 2021
42751da
Backend: run and time on GPU
siddharth-krishna Apr 30, 2021
f73d33b
Refactor grid_search for interactive use
siddharth-krishna Apr 30, 2021
20bae75
Timing code for CPUs
siddharth-krishna May 2, 2021
8e8513d
Handle ops with multiple outputs
siddharth-krishna May 2, 2021
6f3ae59
Add support for all MLP training ops
siddharth-krishna May 2, 2021
49d385c
Type inference: fix function name bug
siddharth-krishna May 2, 2021
2cce7c3
Prettyprint: support for printing FunctionMakers
siddharth-krishna May 2, 2021
91f5ecf
Fix backend op implementations
siddharth-krishna May 2, 2021
172a5db
Convert per-rank fns to Modules inside each thread
siddharth-krishna May 2, 2021
2ed3cf7
Per-rank projector: remove types
siddharth-krishna May 2, 2021
7eee9f9
Add some more tests
siddharth-krishna May 2, 2021
78da37d
Default number of repetitions = 1
siddharth-krishna May 2, 2021
56e05dc
DHP transform: return separate init_fn and transformed fn
siddharth-krishna May 2, 2021
78bc1e9
Run pytest with root dir included in PYTHONPATH
siddharth-krishna May 2, 2021
8a288b6
Grid search: remove unused imports
siddharth-krishna May 2, 2021
4b2b301
Revert unintended changes
siddharth-krishna May 2, 2021
75ec41a
Move run_pytorch to backend.torch
siddharth-krishna May 2, 2021
3db597b
Inference for GPT2
santhnm2 Apr 8, 2021
0b7b6d5
GPT-2 reference execution output matches
santhnm2 Apr 8, 2021
2d67245
Type inference for GPT-2
santhnm2 Apr 13, 2021
75f9ea3
Add PostTypeInferenceSimulator and update cost functions
santhnm2 Apr 22, 2021
4030ca8
Data parallel transform works for GPT-2
santhnm2 Apr 23, 2021
4b0d6e8
Add pipeline parallel partitioning
santhnm2 Apr 23, 2021
5e39f26
In progress SOSP results
santhnm2 Apr 29, 2021
1e68c25
Add horizontal parallelism for GPT-2
santhnm2 Apr 30, 2021
59e93b0
Add GPT-2 grid search, filter extra outputs, fix pipeline parallel pa…
santhnm2 May 4, 2021
ea3ca07
Add tensor split to MLP weights
santhnm2 May 4, 2021
b35bc12
Change dim to axis
santhnm2 May 5, 2021
13d1c0c
Fix concat attribute in torch backend
siddharth-krishna May 5, 2021
6dbf57c
Interpret Function instead of creating fx.Graph
siddharth-krishna May 5, 2021
4a77383
Prettyprint attributes as Python-style kwargs
siddharth-krishna May 5, 2021
e8735c3
Use broadcast with pairwise groups for send/recv on GPUs
siddharth-krishna May 5, 2021
fd2e7b1
Move new tensors to GPU in each op, outputs back to CPU
siddharth-krishna May 5, 2021
d087d5f
All tests pass
santhnm2 May 5, 2021
b23d50b
Merge with lowering branch
santhnm2 May 5, 2021
71e02e4
Separate initialization from computation for GPT-2
santhnm2 May 5, 2021
31031cc
Add a mock multiprocess backend for debugging
siddharth-krishna May 6, 2021
5d99a62
Use spawn start method for multiprocessing
siddharth-krishna May 6, 2021
2c8852a
Fix MLP DHP tests
siddharth-krishna May 6, 2021
1d54fea
Revert "Fix MLP DHP tests"
siddharth-krishna May 6, 2021
18eaa08
Fix MLP DHP tests for real
siddharth-krishna May 6, 2021
c8ccd64
PyTorch backend working for GPT-2 on a single CPU device
santhnm2 May 7, 2021
cab1146
Merge with origin/lowering
santhnm2 May 7, 2021
0ce2fdb
Add code to plot grid search results
siddharth-krishna May 7, 2021
5588b87
Don't use globals while multiprocessing
siddharth-krishna May 7, 2021
7d714d4
Fix mock backend, use_gpu=False by default
siddharth-krishna May 7, 2021
cb63add
Merge with origin/lowering
santhnm2 May 7, 2021
23cc8af
Merge remote-tracking branch 'origin/lowering' into gpt_with_lowering
santhnm2 May 7, 2021
8646842
Partial grid search on 4 devices
siddharth-krishna May 7, 2021
1db0736
Support collectives between a subset of ranks
siddharth-krishna May 10, 2021
fe1f239
Add backend support for GPT-2 grid search
santhnm2 May 10, 2021
720a793
Merge with origin/lowering
santhnm2 May 10, 2021
2df9e74
In progress backend fixes
santhnm2 May 11, 2021
473a125
Remove debug print
santhnm2 May 11, 2021
a79ab67
Bug fixes for distributed process groups (#24)
santhnm2 May 11, 2021
d66c35f
Fix Send/Recv dtypes
santhnm2 May 12, 2021
0c0f7f5
Debugging MLP deadlock
siddharth-krishna May 13, 2021
11ed80d
Fix pipeline parallel output forwarding and add PyTorch profiling to …
santhnm2 May 14, 2021
33cefce
Fix collective projector
santhnm2 May 14, 2021
a5c8b34
Enable grid search test again
siddharth-krishna May 13, 2021
ca68ec4
Some comments, debugging code, cuda sync earlier
siddharth-krishna May 20, 2021
b31209d
Projector for gather
siddharth-krishna May 20, 2021
54ce0d2
Don't save input/outputs to file in torch backend
siddharth-krishna May 25, 2021
5d1b63f
Remove unnecessary global
siddharth-krishna May 25, 2021
3466627
Free tensors after use
siddharth-krishna May 25, 2021
7d16c90
Map DistIR devices to pytorch backend ranks
siddharth-krishna May 26, 2021
8c4f5e1
Fix tests
siddharth-krishna May 26, 2021
79d4a65
Some documentation and cleanup
siddharth-krishna May 26, 2021
da7ff5d
Fix comment
siddharth-krishna May 26, 2021
72589b2
Remove experiment code and dead code
siddharth-krishna May 26, 2021
c9eb6de
GPT-2 updates
santhnm2 May 28, 2021
a22237f
Add docs and input validation for run_pytorch
siddharth-krishna Jun 3, 2021
3a9d264
Profiling seq_mlp using pytorch/tensorboard
siddharth-krishna Jun 3, 2021
3562e8c
Merge with main
santhnm2 Jun 4, 2021
54e94e6
Add test init file
santhnm2 Jun 4, 2021
3263880
Allow for specifying number of GPT transformer blocks
santhnm2 Jun 9, 2021
4ed6a35
Fix pipeline parallel scheduling
santhnm2 Jun 14, 2021
149ecde
Grid search fixes
santhnm2 Jun 17, 2021
c969cf2
Grid search fixes
santhnm2 Jun 17, 2021
0306376
Sort ready stages by microbatch ID
santhnm2 Jun 23, 2021
c69658f
Fix formatting
santhnm2 Jun 23, 2021
8328fde
Add roundrobin to requirements.txt
santhnm2 Jun 23, 2021
f328e18
More formatting fixes
santhnm2 Jun 23, 2021
5842bd0
Update grid search with GPT-3 model sizes
santhnm2 Jun 23, 2021
1ff97e6
Enable mixed type interpretation through the SequentialExecutor
santhnm2 Jun 24, 2021
e91b76d
Only use real weights if flag is enabled
santhnm2 Jun 24, 2021
d5f7141
Only forward pipeline parallel inputs if necessary
santhnm2 Jun 24, 2021
ba47682
Clean up split ops
santhnm2 Jun 24, 2021
df0d66e
Factor out mixed implementations to new file
santhnm2 Jun 24, 2021
d98acd6
Make the gpt code more modular
santhnm2 Jun 25, 2021
13b5951
Parallelized grid search
santhnm2 Jun 25, 2021
af041e5
Add filelock requirement
santhnm2 Jun 25, 2021
69570f4
Fix try-catch
santhnm2 Jun 25, 2021
232e784
Address some of Sid's comments
santhnm2 Jun 29, 2021
7e981fa
Merge with main
santhnm2 Jun 30, 2021
fde6fb9
Address more of Sid's comments
santhnm2 Jul 1, 2021
73a4e1a
Fix formatting
santhnm2 Jul 1, 2021
6d08375
Add TODOs for removing code pending mixed implementations
santhnm2 Jul 1, 2021
b19a4ea
Add docstring to get_all_degrees
santhnm2 Jul 1, 2021
7e128b3
Address more of Sid's comments
santhnm2 Jul 1, 2021
f8df2b9
Merge with gpt_with_lowering
santhnm2 Jul 1, 2021
a575d9a
Merge with main
santhnm2 Jul 8, 2021
9260018
Remove unnecessary changes
santhnm2 Jul 8, 2021
ae07996
[WIP] Merge with mixed_type_inference
santhnm2 Jul 12, 2021
c33a81e
Update Constant device in transform
santhnm2 Jul 12, 2021
ca5875b
Finish merging with mixed_type_inference
santhnm2 Jul 13, 2021
9d7de4b
Remove filelock dependency
santhnm2 Jul 13, 2021
19adfde
[WIP] MLP training results + merge MLP transform with GPT-2 transform
santhnm2 Jul 13, 2021
1243ea1
Add a ConcreteValue class used by simulator
siddharth-krishna Jul 13, 2021
8e08e8f
Use actual outputs in simulator's live memory estimation
siddharth-krishna Jul 13, 2021
d91c8a0
Simulator: use default cost function if not registered
siddharth-krishna Jul 13, 2021
7b0f471
Update MLP transform to more closely match GPT transform
santhnm2 Jul 14, 2021
b8cb938
MLP grid search updates
santhnm2 Jul 16, 2021
ced8bf9
First pass at optimizer op
santhnm2 Jul 21, 2021
d7d1380
Make optimizer ops work with distributed execution
santhnm2 Jul 21, 2021
c7839dd
Fix formatting
santhnm2 Jul 21, 2021
1bba41e
Add benchmark for measuring simulator accuracy
santhnm2 Jul 22, 2021
d6a7c36
Update MLP benchmark
santhnm2 Jul 23, 2021
25f6353
Merge with main
santhnm2 Jul 24, 2021
49a72cc
Add separate execution modes to MLP benchmark
santhnm2 Jul 24, 2021
50cc14e
Fix if/else conditions
santhnm2 Jul 27, 2021
d81bc4c
Add profiling
santhnm2 Jul 27, 2021
c822b49
Use torch.relu
santhnm2 Jul 27, 2021
2c51c5d
Fix time measurement for pure PyTorch
santhnm2 Jul 28, 2021
80db02e
Make device parameters configurable
santhnm2 Jul 29, 2021
0e3b139
Dispatch in most-precise-to-most-abstract order
siddharth-krishna Aug 2, 2021
92cf450
Move type register to separate file
siddharth-krishna Aug 2, 2021
02e2631
Use type abstraction graph for dispatch
siddharth-krishna Aug 2, 2021
321af26
SequentialExecutor: use new unified interpreter
siddharth-krishna Aug 2, 2021
4c01eb7
Add device parameters as arguments
santhnm2 Aug 3, 2021
8b09ac5
Add function to calibrate simulator
santhnm2 Aug 3, 2021
3b621b5
Add kernel launch overhead to device parameters
santhnm2 Aug 6, 2021
7a48fde
Updates to benchmark
santhnm2 Aug 6, 2021
a72282b
Update requirements.txt
santhnm2 Aug 7, 2021
c3111f9
Use float32 data type
santhnm2 Aug 11, 2021
576b1c8
Force intercept to be positive
santhnm2 Aug 11, 2021
b769b1c
Wrap inputs to unified interpreter in ConcreteValue
siddharth-krishna Aug 16, 2021
b5688e8
Refactor simulator to use unified interpreter
siddharth-krishna Aug 17, 2021
bd45e25
Refactor projector to use unified interpreter
siddharth-krishna Aug 17, 2021
58dd879
Clean up simulator.py
siddharth-krishna Aug 17, 2021
d9a6b86
Clean up sequential_executor.py
siddharth-krishna Aug 17, 2021
ce8c66c
Merge branch 'main' into simulator-refactor
siddharth-krishna Aug 17, 2021
42fa21c
Add networkx to required packages
siddharth-krishna Aug 17, 2021
bdf48ed
Clean up absint.py
siddharth-krishna Aug 17, 2021
357e5e3
Attempt to fix GPT grid search
siddharth-krishna Aug 17, 2021
e35edfd
[WIP] Add network bandwidth calibration and distributed grid search
santhnm2 Aug 17, 2021
1bc071f
Use Git LFS to download GPT onnx model
siddharth-krishna Aug 18, 2021
9f28fdc
Download onnx/models to /tmp to avoid Black errors
siddharth-krishna Aug 18, 2021
590e98c
Disable torch backend tests
siddharth-krishna Aug 18, 2021
711bc53
Temporarily disable GPT tests
siddharth-krishna Aug 18, 2021
59d6e37
[WIP] Merge with main
santhnm2 Aug 21, 2021
9961f86
Fix gpt2
santhnm2 Aug 23, 2021
64d8563
Update network calibration
santhnm2 Aug 23, 2021
8d33c65
Increase batch size for DGX run
santhnm2 Aug 23, 2021
afe629e
Fix configs for distributed grid search
santhnm2 Aug 23, 2021
07892d6
Load/save simulation parameters from/to a file
santhnm2 Aug 23, 2021
a51c539
Address Sid's comments
santhnm2 Aug 23, 2021
96a3d4d
Bug fixes
santhnm2 Aug 23, 2021
9b6a1f9
Merge with grid_search_optimizations
santhnm2 Aug 23, 2021
b891894
GPT-2 simulation working
santhnm2 Aug 24, 2021
771bc18
Add communication register
santhnm2 Aug 24, 2021
067821a
Remove dead code
santhnm2 Aug 24, 2021
42da615
Update allreduce benchmark
santhnm2 Aug 24, 2021
a8d629b
Add examples/calibrate_simulator.py
santhnm2 Aug 24, 2021
3416e97
Merge with grid_search_optimizations
santhnm2 Aug 25, 2021
7d030c3
Revert changes to simulator
santhnm2 Aug 25, 2021
4ec4ded
[WIP] updated allreduce cost function and grid search updates
santhnm2 Aug 25, 2021
124b189
[WIP] MLP grid search results for presentation
siddharth-krishna Aug 25, 2021
944544f
WIP gpt2 pytorch grid search
santhnm2 Aug 26, 2021
b567f57
add gpt2 benchmark
santhnm2 Aug 26, 2021
02742ed
Add pytorch gpt2 grid search
santhnm2 Aug 26, 2021
0afd3af
Merge with main
santhnm2 Aug 27, 2021
76a8560
Add additional TODO item
santhnm2 Aug 27, 2021
d4074b1
Docstring
siddharth-krishna Aug 26, 2021
5b91e64
Abstract values when necessary during dispatch
siddharth-krishna Aug 27, 2021
c000b4b
Fix simulator's abstraction during dispatch
siddharth-krishna Aug 27, 2021
9eaaa52
Fix projector
siddharth-krishna Aug 27, 2021
73a8b69
GPT-6.7B grid search for ORT presentation
siddharth-krishna Aug 27, 2021
b29a395
MLP pytorch gridsearch
siddharth-krishna Aug 27, 2021
05ec369
Simplify rank projector, don't rely on function's type attributes
siddharth-krishna Aug 27, 2021
f863d3a
Clean-up
siddharth-krishna Aug 27, 2021
79dbe92
Add input_types argument to run_pytorch
santhnm2 Aug 30, 2021
27661d4
Fix send type inference / projection and add GPT2 pytorch tests
santhnm2 Aug 30, 2021
86594cd
Fix formatting
santhnm2 Aug 30, 2021
508a223
Update requirements.txt
santhnm2 Aug 30, 2021
e2a86c7
Clean up code
santhnm2 Aug 30, 2021
b540dec
fixes
santhnm2 Aug 30, 2021
396fb0a
Merge with mlp_training
santhnm2 Aug 30, 2021
2663fc3
Fix global group for distributed barrier
santhnm2 Aug 31, 2021
806e3e3
Fix send cost function and start merging mlp_benchmark with mlp_grid_…
santhnm2 Aug 31, 2021
267fc4f
Add allreduce parameter calibration
santhnm2 Aug 31, 2021
a4bb0fc
Add notebook for MLP training simulator accuracy
santhnm2 Aug 31, 2021
591e913
Cleanup
siddharth-krishna Sep 1, 2021
f68db1c
More absint tests
siddharth-krishna Sep 1, 2021
f7150e6
Test pytorch backend with custom input_types
siddharth-krishna Sep 1, 2021
55ddd3a
Test simulator works on untyped function
siddharth-krishna Sep 1, 2021
9085efd
Add docstrings
siddharth-krishna Sep 1, 2021
b57a5d6
Clean up chrome trace test
siddharth-krishna Sep 1, 2021
87f8418
Test output correctness for GPT with PyTorch backend
santhnm2 Sep 1, 2021
343a8be
Address Sid's comments
santhnm2 Sep 1, 2021
6fc0ade
Abstract input sequences for GPT and simplify tests
santhnm2 Sep 1, 2021
51620ef
Fix test
santhnm2 Sep 1, 2021
85b8287
[WIP] Record op-level traces
santhnm2 Sep 2, 2021
2a7881c
Time each op with torch.cuda.synchronize
santhnm2 Sep 3, 2021
f8cd97a
Address Sid's comments
santhnm2 Sep 3, 2021
bae172b
Merge branch 'simulator-refactor' into mlp_training
siddharth-krishna Sep 3, 2021
ab376a2
Fix mlp_grid_search to use new simulator
siddharth-krishna Sep 2, 2021
68346f2
Fix SGD optimizer and warnings
santhnm2 Sep 4, 2021
0db2160
Formatting fix
santhnm2 Sep 4, 2021
9eebe67
[WIP] Consolidated grid search infrastructure
santhnm2 Sep 5, 2021
c1c0b88
Fix GPT-2 grid search and memory estimation
santhnm2 Sep 5, 2021
28ebcaf
Replace GPT2 grid search
santhnm2 Sep 5, 2021
eefc1b1
Fix tests
santhnm2 Sep 5, 2021
4f58091
Update MLP grid search
santhnm2 Sep 5, 2021
ce62de3
Add grid search tests
santhnm2 Sep 6, 2021
2a9d296
Add constants file
santhnm2 Sep 6, 2021
1c24259
Merge branch 'main' into mlp_training
siddharth-krishna Sep 6, 2021
7bfb52a
Defer input data generation to per-process execution
santhnm2 Sep 6, 2021
98117c9
Add grid search tests
santhnm2 Sep 6, 2021
37ec418
[WIP] debugging send inconsistencies
santhnm2 Sep 7, 2021
1ad646e
Add default CPU->GPU bandwidth
santhnm2 Sep 7, 2021
39a69ab
Add send benchmark
santhnm2 Sep 7, 2021
072be32
Grid search fixes
santhnm2 Sep 7, 2021
77eb8b2
Merge branch 'simulator_accuracy' of github.com:microsoft/dist-ir int…
santhnm2 Sep 7, 2021
2fed503
Allgather cost function fix
santhnm2 Sep 7, 2021
93cb506
grid search: if output file exists, warn and append
siddharth-krishna Sep 7, 2021
c2532f0
Use underscores consistently for command line args
siddharth-krishna Sep 8, 2021
07a62bf
Fix gpt and gpt grid search
siddharth-krishna Sep 8, 2021
e657d87
More gpt fixes
siddharth-krishna Sep 8, 2021
fc925e8
Don't use tqdm for backend grid search, catch RuntimeErrors
siddharth-krishna Sep 8, 2021
9e5c36c
Use NamedTuple for grid search configs
siddharth-krishna Sep 8, 2021
be3ed43
Grid search: skip configs already in output file
siddharth-krishna Sep 8, 2021
0c7e60d
Fix: revert to using lock for backend grid search
siddharth-krishna Sep 8, 2021
cfc8eb3
Run pytorch grid search in ascending memory order
santhnm2 Sep 9, 2021
6ec8625
Fix grid search tests
santhnm2 Sep 9, 2021
ef8ac41
Feature: run single or multiple configs from file
siddharth-krishna Sep 9, 2021
1b4c151
Make batch size an argument to loss function
santhnm2 Sep 9, 2021
b6283a6
Number configurations from 1 not 0
siddharth-krishna Sep 9, 2021
cb1fede
Add argument to append to output file
siddharth-krishna Sep 9, 2021
a581338
Add shell script to run backend grid search
siddharth-krishna Sep 9, 2021
7b63d54
Fix tests
santhnm2 Sep 9, 2021
618bd15
[WIP] Merge with mlp_training
santhnm2 Sep 10, 2021
39269ee
Remove recv buffers
santhnm2 Sep 11, 2021
5dacd53
Update notebook
santhnm2 Sep 13, 2021
1b87689
Merge with main
santhnm2 Sep 17, 2021
156 changes: 131 additions & 25 deletions dist_ir/backend/torch.py
@@ -1,3 +1,5 @@
import itertools
import json
from functools import partial
import numpy as np
from operator import getitem
@@ -6,6 +8,7 @@
from time import perf_counter
from traceback import print_exc
from typing import Any, Dict, Iterable, List, NamedTuple, Sequence, Tuple
import time

import torch
import torch.distributed as dist
@@ -30,10 +33,14 @@
groups=Dict[Tuple[int], Any],
# Temp store of group IDs until threads can create ProcessGroups
groups_list=Iterable[Tuple[int]],
# Group encompassing all devices
global_group=Tuple[int],
# Debug flag
debug_stacktrace=bool,
# Profile flag
profile=bool,
# List of op execution events
trace=list,
)


@@ -73,7 +80,7 @@ def _cast(x, to, ctx=None):
raise NotImplementedError()


def _concat2(*args, axis=None, ctx=None):
def _concat(*args, axis=None, ctx=None):
return torch.cat(args, dim=axis)


@@ -162,14 +169,25 @@ def _reshape(x, y, ctx=None):


def _recv(shape=None, from_d=None, group=None, dtype=None, ctx=None):
if isinstance(dtype, Int32):
x = torch.zeros(shape).int()
elif isinstance(dtype, Int64):
x = torch.zeros(shape).long()
elif isinstance(dtype, Float32):
x = torch.zeros(shape).float()
# torch.distributed.barrier(group=ctx.groups[group])
if len(shape) == 0:
if isinstance(dtype, Int32):
x = torch.tensor(0).int()
if isinstance(dtype, Int64):
x = torch.tensor(0).long()
elif isinstance(dtype, Float32):
x = torch.tensor(0).float()
else:
raise NotImplementedError(dtype)
else:
raise NotImplementedError(dtype)
if isinstance(dtype, Int32):
x = torch.zeros(shape).int()
if isinstance(dtype, Int64):
x = torch.zeros(shape).long()
elif isinstance(dtype, Float32):
x = torch.zeros(shape).float()
else:
raise NotImplementedError(dtype)

src_rank = ctx.device_to_rank[from_d]
if ctx.use_gpu:
@@ -191,6 +209,7 @@ def _relu_grad(x, dy, ctx=None):


def _send(x, to_d=None, group=None, ctx=None):
# torch.distributed.barrier(group=ctx.groups[group])
if ctx.use_gpu:
src_rank = dist.get_rank()
dist.broadcast(x, src_rank, group=ctx.groups[group])
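On GPUs, `_send` (and the corresponding `_recv` path) emulates a point-to-point transfer by broadcasting within the two-rank group for the sender/receiver pair, matching the commit "Use broadcast with pairwise groups for send/recv on GPUs". A minimal sketch of that pattern, assuming the pairwise groups were created collectively on every rank beforehand; the helper names below are illustrative and not part of the PR:

```python
# Hedged sketch (not from the PR): point-to-point transfer via a broadcast
# inside a two-rank process group. dist.new_group is collective, so every
# rank must create every pairwise group up front (the PR pre-creates them
# from ctx.groups_list).
import torch
import torch.distributed as dist

def make_pairwise_group(src_rank: int, dst_rank: int):
    return dist.new_group([src_rank, dst_rank])

def pairwise_send(t: torch.Tensor, src_rank: int, group) -> None:
    dist.broadcast(t, src_rank, group=group)    # receiver gets a copy of t

def pairwise_recv(shape, dtype, src_rank: int, group) -> torch.Tensor:
    buf = torch.zeros(shape, dtype=dtype, device="cuda")
    dist.broadcast(buf, src_rank, group=group)  # buf is filled with sender's data
    return buf
```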
@@ -277,7 +296,7 @@ def _unsqueeze(x, axes, ctx=None):
"Add": torch.add,
"Cast": _cast,
"Add": _add,
"Concat": _concat2,
"Concat": _concat,
"Constant": _constant,
"ConstantOfShape": _constant_of_shape,
"Div": _div,
@@ -379,16 +398,26 @@ def function_to_module(fn: Function) -> torch.nn.Module:
return fx.GraphModule({}, g)


def add_event(ctx, events):
if ctx.use_gpu:
events.append(torch.cuda.Event(enable_timing=True))
events[-1].record()
else:
events.append(perf_counter())


def run_function(
ctx: DistributedContext,
fn: Function,
inputs: List[Any],
rank: int,
debug_mock=False,
op_runtimes_ts: float = None,
):
"""Runs DistIR Function `fn` on `inputs` in a distributed context `ctx` by
converting each DistIR op to its torch implementation as given in _op_to_torch.
"""
record_op_runtimes = op_runtimes_ts is not None
op_to_torch = _mock_op_to_torch if debug_mock else _op_to_torch
value_map = {}

@@ -403,13 +432,39 @@ def print_memory_usage():
a = torch.cuda.memory_allocated(0)
print(f"Total: {t} Reserved: {r} Allocated: {a} Free: {r-a}")

if record_op_runtimes:
op_runtimes = []

# Run ops
for op in fn.ops:
inputs = tuple(value_map[v] for v in op.inputs)
kwargs = {} if op.attributes is None else {**op.attributes}
kwargs["ctx"] = ctx

# TODO: Consider adding this to mitigate network contention:
# if "MPI" in op.op_type or op.op_type == "Send":
# torch.cuda.synchronize()

if record_op_runtimes:
start = time.time()
output = op_to_torch[op.op_type](*inputs, **kwargs)
if record_op_runtimes:
if ctx.use_gpu:
torch.cuda.synchronize(device=rank)
end = time.time()
if op.op_type == "SendP2P":
x = inputs[0]
src_rank = dist.get_rank()
dst_rank = ctx.device_to_rank[kwargs["to_d"]]
group = ctx.groups[kwargs["group"]]
latency = end - start
print(
f"Sending tensor of size {x.size()} on device {x.device} with dtype "
f"{x.dtype} from device {src_rank} to {dst_rank}: latency={latency}, "
f"throughput={x.shape[0] * x.shape[1] * 4 / 1.25e8 / latency}"
)

op_runtimes.append(end - start)

if len(op.outputs) > 1:
assert isinstance(output, tuple)
@@ -423,6 +478,24 @@ def print_memory_usage():
if v in value_map and fn.last_use(v) == op and not (v in fn.outputs):
del value_map[v]

if record_op_runtimes:
trace = []
ts = op_runtimes_ts
assert len(fn.ops) == len(op_runtimes)
for op, runtime in zip(fn.ops, op_runtimes):
trace.append(
{
"name": op.op_type,
"ph": "X",
"ts": ts,
"dur": runtime * 1e6,
"pid": 0,
"tid": rank + 1,
}
)
ts += runtime * 1e6
ctx.trace[rank] += trace

# Return outputs
return tuple(value_map[v] for v in fn.outputs)
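The per-op records built above follow the Chrome Trace Event Format: complete events (`"ph": "X"`) with `ts`/`dur` in microseconds and the rank encoded as `tid` (rank + 1). The merged `trace.json` that `run_multiprocesses` writes can therefore be opened in `chrome://tracing` or Perfetto, or summarized directly. A hedged sketch, assuming the `<fn.name>_profile/trace.json` path used in this diff:

```python
import json
from collections import defaultdict

def total_op_time_per_tid(path: str) -> dict:
    """Sum op durations (in seconds) per trace thread id (rank + 1)."""
    with open(path) as f:
        events = json.load(f)
    totals = defaultdict(float)
    for e in events:
        totals[e["tid"]] += e["dur"] / 1e6  # dur is recorded in microseconds
    return dict(totals)

# Example: totals = total_op_time_per_tid("fn_profile/trace.json")
```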

@@ -443,20 +516,15 @@ def run_process(ctx, num_warmup_steps, num_repetitions, rank, fn, inputs):
ranks = [ctx.device_to_rank[d] for d in group]
# ctx is a curried arg, hence is thread-local and can be modified:
ctx.groups[group] = dist.new_group(ranks)
global_group_ranks = sorted([ctx.device_to_rank[d] for d in ctx.global_group])
global_group = dist.new_group(global_group_ranks)

if ctx.use_gpu:
# Move inputs to GPU
inputs = [t.cuda(rank) for t in inputs]

events = []

def add_event():
if ctx.use_gpu:
events.append(torch.cuda.Event(enable_timing=True))
events[-1].record()
else:
events.append(perf_counter())

if ctx.profile:
num_wait_steps = 0
else:
@@ -466,7 +534,7 @@ def add_event():
try:
outputs = run_function(ctx, fn, inputs, rank)
if ctx.world_size > 1:
torch.distributed.barrier()
torch.distributed.barrier(group=global_group)
except Exception as e:
print_exc()
print(f"{rank}: PyTorch backend exiting after 1 run in debug mode.")
@@ -482,18 +550,32 @@ def add_event():
schedule=torch.profiler.schedule(
wait=num_wait_steps, warmup=num_warmup_steps, active=num_repetitions
),
on_trace_ready=torch.profiler.tensorboard_trace_handler(
f"{fn.name}_{rank}_profile"
),
on_trace_ready=torch.profiler.tensorboard_trace_handler(f"{fn.name}_profile"),
) as p:
op_runtimes_ts = None
for i in range(num_warmup_steps + num_repetitions):
add_event()
record_op_runtimes = ctx.profile and i >= num_warmup_steps
if record_op_runtimes and op_runtimes_ts is None:
op_runtimes_ts = 0.0
add_event(ctx, events)
# TODO: Handle failures here?
outputs = run_function(ctx, fn, inputs, rank)
outputs = run_function(
ctx,
fn,
inputs,
rank,
op_runtimes_ts=op_runtimes_ts,
)
if ctx.world_size > 1:
torch.distributed.barrier()
add_event()
torch.distributed.barrier(group=global_group)
if i == (num_warmup_steps + num_repetitions - 1):
add_event(ctx, events)
p.step()
if record_op_runtimes:
op_runtimes_ts = max(
ctx.trace[rank][-1]["ts"] + ctx.trace[rank][-1]["dur"]
for rank in ctx.trace.keys()
)

if ctx.use_gpu:
# Move outputs back to cpu
@@ -553,6 +635,11 @@ def run_multiprocesses(
if ctx.debug_stacktrace:
sys.exit(1)

if ctx.profile:
trace = list(itertools.chain.from_iterable(list(ctx.trace.values())))
with open(f"{per_rank_functions[0].name}_profile/trace.json", "w") as f:
json.dump(trace, f, indent=0)

per_rank_outputs, runtimes = zip(*outputs)
return per_rank_outputs, runtimes

@@ -592,23 +679,42 @@ def run_pytorch(

device_to_fns, groups = project(fn, input_types)

if len(device_to_fns) > torch.cuda.device_count():
raise ValueError(
f"Received {len(device_to_fns)} projected functions, "
f"but only {torch.cuda.device_count()} GPUs available"
)

# Map between DistIR devices and pytorch ranks:
device_to_rank = {}
world_size = 0
per_rank_fns = []
for d in device_to_fns:
device_to_rank[d] = world_size
rank = world_size
device_to_rank[d] = rank
per_rank_fns.append(device_to_fns[d])
world_size += 1

global_group = tuple(sorted(device_to_fns.keys()))

if profile:
manager = torch.multiprocessing.Manager()
trace = manager.dict()
for d in sorted(device_to_rank.keys()):
trace[device_to_rank[d]] = []
else:
trace = None

ctx = DistributedContext(
world_size=world_size,
use_gpu=use_gpu,
groups={},
groups_list=list(groups),
global_group=global_group,
device_to_rank=device_to_rank,
debug_stacktrace=debug_stacktrace,
profile=profile,
trace=trace,
)

per_rank_inputs = [[] for _ in range(world_size)]
6 changes: 6 additions & 0 deletions dist_ir/executor/__init__.py
@@ -1,4 +1,10 @@
from .absint import AbstractInterpreter, AbstractState
from .calibrate_simulator import (
calibrate_device_parameters,
calibrate_network_bandwidth,
calibrate_allreduce_parameters,
network_bandwidth_debug, # TODO: Remove
)
from .concrete_value import ConcreteValue
from .cost_model import CostModel
from .simulator import Simulator
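These re-exports let calibration scripts (for example, `examples/calibrate_simulator.py` from the commit list) import the calibration entry points from the package root. Only the names are visible in this diff, so the snippet below shows the import alone and makes no assumptions about argument lists:

```python
from dist_ir.executor import (
    calibrate_device_parameters,
    calibrate_network_bandwidth,
    calibrate_allreduce_parameters,
)
```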