Add Protobuf definition for DistIR functions #25

Draft: wants to merge 61 commits into base: main

Commits (61)

8ac2c50
Projector: DistIR function -> per-rank function
siddharth-krishna Apr 19, 2021
193dc84
Fix bug: constructing variadic ops when specifying output_values
siddharth-krishna Apr 19, 2021
c5bfbd9
Fix sequential executor docstring
siddharth-krishna Apr 19, 2021
d2f9e0a
A distributed PyTorch backend
siddharth-krishna Apr 25, 2021
d9cc787
Upgrade PyTorch version
siddharth-krishna Apr 25, 2021
a2a4cc2
Add test: one-weird-trick (matmul version)
siddharth-krishna Apr 25, 2021
7ab770d
Refactor run_multiprocess
siddharth-krishna Apr 25, 2021
4e37ce3
Make run_multiprocess take Functions not nn.Modules
siddharth-krishna Apr 25, 2021
9576811
End-of-file newlines
siddharth-krishna Apr 25, 2021
c4ca1a0
Black
siddharth-krishna Apr 27, 2021
a095665
Parametrize pytest test_owt
siddharth-krishna Apr 27, 2021
3279ca1
Revert "Fix bug: constructing variadic ops when specifying output_val…
siddharth-krishna Apr 27, 2021
c84c61d
Fix Op constructor: better handling of variadic ops/pre-created outputs
siddharth-krishna Apr 27, 2021
bb604f6
Add support for Relu
siddharth-krishna Apr 28, 2021
8d8bbda
Add DP test
siddharth-krishna Apr 29, 2021
42751da
Backend: run and time on GPU
siddharth-krishna Apr 30, 2021
f73d33b
Refactor grid_search for interactive use
siddharth-krishna Apr 30, 2021
20bae75
Timing code for CPUs
siddharth-krishna May 2, 2021
8e8513d
Handle ops with multiple outputs
siddharth-krishna May 2, 2021
6f3ae59
Add support for all MLP training ops
siddharth-krishna May 2, 2021
49d385c
Type inference: fix function name bug
siddharth-krishna May 2, 2021
2cce7c3
Prettyprint: support for printing FunctionMakers
siddharth-krishna May 2, 2021
91f5ecf
Fix backend op implementations
siddharth-krishna May 2, 2021
172a5db
Convert per-rank fns to Modules inside each thread
siddharth-krishna May 2, 2021
2ed3cf7
Per-rank projector: remove types
siddharth-krishna May 2, 2021
7eee9f9
Add some more tests
siddharth-krishna May 2, 2021
78da37d
Default number of repetitions = 1
siddharth-krishna May 2, 2021
56e05dc
DHP transform: return separate init_fn and transformed fn
siddharth-krishna May 2, 2021
78bc1e9
Run pytest with root dir included in PYTHONPATH
siddharth-krishna May 2, 2021
8a288b6
Grid search: remove unused imports
siddharth-krishna May 2, 2021
4b2b301
Revert unintended changes
siddharth-krishna May 2, 2021
75ec41a
Move run_pytorch to backend.torch
siddharth-krishna May 2, 2021
6dbf57c
Interpret Function instead of creating fx.Graph
siddharth-krishna May 5, 2021
4a77383
Prettyprint attributes as Python-style kwargs
siddharth-krishna May 5, 2021
e8735c3
Use broadcast with pairwise groups for send/recv on GPUs
siddharth-krishna May 5, 2021
fd2e7b1
Move new tensors to GPU in each op, outputs back to CPU
siddharth-krishna May 5, 2021
31031cc
Add a mock multiprocess backend for debugging
siddharth-krishna May 6, 2021
5d99a62
Use spawn start method for multiprocessing
siddharth-krishna May 6, 2021
2c8852a
Fix MLP DHP tests
siddharth-krishna May 6, 2021
1d54fea
Revert "Fix MLP DHP tests"
siddharth-krishna May 6, 2021
18eaa08
Fix MLP DHP tests for real
siddharth-krishna May 6, 2021
0ce2fdb
Add code to plot grid search results
siddharth-krishna May 7, 2021
5588b87
Don't use globals while multiprocessing
siddharth-krishna May 7, 2021
7d714d4
Fix mock backend, use_gpu=False by default
siddharth-krishna May 7, 2021
8646842
Partial grid search on 4 devices
siddharth-krishna May 7, 2021
1db0736
Support collectives between a subset of ranks
siddharth-krishna May 10, 2021
a79ab67
Bug fixes for distributed process groups (#24)
santhnm2 May 11, 2021
0c0f7f5
Debugging MLP deadlock
siddharth-krishna May 13, 2021
33cefce
Fix collective projector
santhnm2 May 14, 2021
a5c8b34
Enable grid search test again
siddharth-krishna May 13, 2021
ca68ec4
Some comments, debugging code, cuda sync earlier
siddharth-krishna May 20, 2021
b31209d
Projector for gather
siddharth-krishna May 20, 2021
54ce0d2
Don't save input/outputs to file in torch backend
siddharth-krishna May 25, 2021
5d1b63f
Remove unnecessary global
siddharth-krishna May 25, 2021
3466627
Free tensors after use
siddharth-krishna May 25, 2021
7d16c90
Map DistIR devices to pytorch backend ranks
siddharth-krishna May 26, 2021
8c4f5e1
Fix tests
siddharth-krishna May 26, 2021
79d4a65
Some documentation and cleanup
siddharth-krishna May 26, 2021
da7ff5d
Fix comment
siddharth-krishna May 26, 2021
72589b2
Remove experiment code and dead code
siddharth-krishna May 26, 2021
d832fa5
Add protobuf definition for DistIR
santhnm2 May 28, 2021

Files changed
.github/workflows/tests.yml (2 changes: 1 addition & 1 deletion)

@@ -52,4 +52,4 @@ jobs:
        run: python setup.py install

      - name: Test with pytest
-       run: pytest
+       run: python -m pytest

dist_ir/backend/__init__.py (1 change: 1 addition & 0 deletions)

@@ -0,0 +1 @@
from . import torch

dist_ir/backend/torch.py (390 changes: 390 additions & 0 deletions)

@@ -0,0 +1,390 @@
from functools import partial
from operator import getitem
import os
import sys
from time import perf_counter
from traceback import print_exc
from typing import Any, Dict, Iterable, List, NamedTuple, Tuple

import torch
import torch.distributed as dist
from torch import fx

from ..executor.rank_projector import project
from ..ir import Function, cpprint, pformat
from ..ir.device import Device


DistributedContext = NamedTuple(
    "DistributedContext",
    world_size=int,
    use_gpu=bool,
    # Map from DistIR device to PyTorch backend rank
    device_to_rank=Dict[Device, int],
    # Maps tuple of ranks to ProcessGroup
    groups=Dict[Tuple[int], Any],
    # Temp store of group IDs until threads can create ProcessGroups
    groups_list=Iterable[Tuple[int]],
)


# TODO organize by category


def _add(x, y, ctx=None):
    return torch.add(x, y)


# TODO kwargs of these functions are required, enforce this somewhere
def _allgather(x_i, dim=0, group=None, ctx=None):
    xs = [torch.zeros_like(x_i) for _ in range(len(group))]
    if ctx.use_gpu:
        xs = [x.cuda(dist.get_rank()) for x in xs]

    dist.all_gather(xs, x_i, group=ctx.groups[group])
    x = torch.cat(xs, dim=dim)
    return x


def _allreduce(x, group=None, ctx=None):
    dist.all_reduce(x, group=ctx.groups[group])
    return x


def _concat2(x, y, dim=None, ctx=None):
    return torch.cat((x, y), dim=dim)


def _identity(x, ctx=None):
    return x


def _loss(x, y, N=None, ctx=None):
    return torch.square(x - y) / N


def _loss_grad(x, y, N=None, ctx=None):
    return 2 * (x - y) / N


def _matmul(x, y, ctx=None):
    return torch.matmul(x, y)


def _matmul_grad(x, y, dz, ctx=None):
    return (torch.matmul(dz, y.T), torch.matmul(x.T, dz))


def _recv(shape=None, from_d=None, group=None, ctx=None):
    x = torch.zeros(shape)
    src_rank = ctx.device_to_rank[from_d]
    if ctx.use_gpu:
        x = x.cuda(dist.get_rank())
        dist.broadcast(x, src_rank, group=ctx.groups[group])
    else:
        dist.recv(x, src_rank)
    return x


def _relu(x, ctx=None):
    return torch.relu(x)


def _relu_grad(x, dy, ctx=None):
    dx = dy.clone()
    dx[x <= 0] = 0
    return dx


def _send(x, to_d=None, group=None, ctx=None):
    if ctx.use_gpu:
        src_rank = dist.get_rank()
        dist.broadcast(x, src_rank, group=ctx.groups[group])
    else:
        dst_rank = ctx.device_to_rank[to_d]
        dist.send(x, dst_rank)
    # Note: in a proper backend, might want to concatenate multiple tensors into
    # a single buffer and call a single send op
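

# A minimal sketch of the batching idea mentioned in the note above (not part
# of this PR): pack several tensors into one flat buffer, send it with a single
# dist.send, and unpack it on the receiving side. The helper names and the
# `shapes` argument are hypothetical.
def _send_batched_sketch(tensors, to_d=None, ctx=None):
    # Flatten each tensor and concatenate into one contiguous 1-D buffer
    buf = torch.cat([t.reshape(-1) for t in tensors])
    dist.send(buf, ctx.device_to_rank[to_d])


def _recv_batched_sketch(shapes=None, from_d=None, ctx=None):
    # Receive the flat buffer, then split and reshape back to the original shapes
    numels = [int(torch.tensor(shape).prod()) for shape in shapes]
    buf = torch.zeros(sum(numels))
    dist.recv(buf, ctx.device_to_rank[from_d])
    return [c.reshape(shape) for c, shape in zip(torch.split(buf, numels), shapes)]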


_op_to_torch = {
    "Add": _add,
    "Concat": _concat2,
    "Identity": _identity,
    "Loss": _loss,
    "LossGrad": _loss_grad,
    "MatMul": _matmul,
    "MatMulGrad": _matmul_grad,
    "RecvP2P": _recv,
    "Relu": _relu,
    "ReluGrad": _relu_grad,
    "SendP2P": _send,
    "MPIAllgather": _allgather,
    "MPIAllreduce": _allreduce,
}

# Some mock communication ops that return zero tensors of appropriate shape
# to be used in the sequential runner for debugging

_mock_world_size = None


def _mock_allgather(x_i, dim=0, ctx=None):
    xs = [torch.zeros_like(x_i) for _ in range(_mock_world_size)]
    x = torch.cat(xs, dim=dim)
    return x


def _mock_allreduce(x, ctx=None):
    return x


def _mock_recv(shape=None, device=None, ctx=None):
    x = torch.zeros(shape)
    return x


def _mock_send(x, device=None, ctx=None):
    pass


_mock_comm_ops = {
    "RecvP2P": _mock_recv,
    "SendP2P": _mock_send,
    "MPIAllgather": _mock_allgather,
    "MPIAllreduce": _mock_allreduce,
}

_mock_op_to_torch = {**_op_to_torch, **_mock_comm_ops}


def function_to_module(fn: Function) -> torch.nn.Module:
    """Deprecated. Converts a DistIR Function to a PyTorch nn.Module using
    torch.fx.
    """
    g = fx.Graph()
    value_map = {}

    # Convert inputs
    for v in fn.inputs:
        value_map[v] = g.placeholder(v.name)

    # Convert ops
    for op in fn.ops:
        inputs = tuple(value_map[v] for v in op.inputs)
        kwargs = None if op.attributes is None else {**op.attributes}
        output = g.call_function(_op_to_torch[op.op_type], inputs, kwargs)
        if len(op.outputs) > 1:
            for i, v in enumerate(op.outputs):
                value_map[v] = g.call_function(getitem, (output, i))
        elif len(op.outputs) == 1:
            value_map[op.outputs[0]] = output

    # Convert outputs
    g.output(tuple(value_map[v] for v in fn.outputs))

    return fx.GraphModule({}, g)


def run_function(
    ctx: DistributedContext,
    fn: Function,
    inputs: List[Any],
    debug_mock=False,
):
    """Runs DistIR Function `fn` on `inputs` in a distributed context `ctx` by
    converting each DistIR op to its torch implementation as given in _op_to_torch.
    """
    op_to_torch = _mock_op_to_torch if debug_mock else _op_to_torch
    value_map = {}

    # Add inputs to value_map
    for v, x in zip(fn.inputs, inputs):
        value_map[v] = x
    assert len(fn.inputs) == len(inputs)

    # Run ops
    for op in fn.ops:
        # op_str = pformat(op).replace("\n", " ")
        # print(f"{rank}: {op_str}")
        # sys.stdout.flush()
        inputs = tuple(value_map[v] for v in op.inputs)
        kwargs = {} if op.attributes is None else {**op.attributes}
        kwargs["ctx"] = ctx

        output = op_to_torch[op.op_type](*inputs, **kwargs)

        if len(op.outputs) > 1:
            assert isinstance(output, tuple)
            for i, v in enumerate(op.outputs):
                value_map[v] = output[i]
        elif len(op.outputs) == 1:
            value_map[op.outputs[0]] = output

        # Free tensors that are not used again
        for v in op.inputs:
            if v in value_map and fn.last_use(v) == op and not (v in fn.outputs):
                del value_map[v]

        # print(f"{rank}: {op_str}")
        # sys.stdout.flush()

    # Return outputs
    return tuple(value_map[v] for v in fn.outputs)


def run_process(ctx, num_warmup_steps, num_repetitions, rank, fn, inputs):
    """The Python function on rank `rank` that runs DistIR function `fn` on
    (torch) inputs `inputs`. The function is run
    `num_warmup_steps + num_repetitions` times. The outputs of the last run are
    returned, along with the last `num_repetitions` runtimes.
    """
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    backend = "nccl" if ctx.use_gpu else "gloo"
    dist.init_process_group(backend, rank=rank, world_size=ctx.world_size)

    # Create the process groups used by fn's communication ops
    for group in ctx.groups_list:
        ranks = [ctx.device_to_rank[d] for d in group]
        # ctx is a curried arg, hence is thread-local and can be modified:
        ctx.groups[group] = dist.new_group(ranks)

    if ctx.use_gpu:
        # Move inputs to GPU
        inputs = [t.cuda(rank) for t in inputs]

    events = []

    def add_event():
        if ctx.use_gpu:
            events.append(torch.cuda.Event(enable_timing=True))
            events[-1].record()
        else:
            events.append(perf_counter())

    # Time a bunch of executions, then execute once for output values
    add_event()
    for _ in range(num_warmup_steps + num_repetitions):
        # try:
        #     outputs = run_function(ctx, fn, inputs)
        # except Exception as e:
        #     print_exc()
        #     sys.exit(1)
        outputs = run_function(ctx, fn, inputs)
        if ctx.world_size > 1:
            torch.distributed.barrier()
        add_event()

    if ctx.use_gpu:
        # Move outputs back to cpu
        outputs = [t.cpu() for t in outputs]

    if ctx.use_gpu:
        torch.cuda.synchronize()
        runtimes = [
            events[i].elapsed_time(events[i + 1]) / 1e3 for i in range(len(events) - 1)
        ]
    else:
        runtimes = [events[i + 1] - events[i] for i in range(len(events) - 1)]

    dist.destroy_process_group()
    return outputs, runtimes[num_warmup_steps:]


def run_mock_multiprocess(
    per_rank_functions: Tuple[Function],
    per_rank_inputs: Tuple[Any],
    num_repetitions=1,
    num_warmup=0,
):
    assert len(per_rank_functions) == len(per_rank_inputs)
    global _mock_world_size
    _mock_world_size = len(per_rank_functions)
    ctx = DistributedContext(use_gpu=False, groups=None)

    per_rank_outputs = [
        run_function(ctx, fn, inputs, debug_mock=True)
        for rank, fn, inputs in zip(
            range(_mock_world_size), per_rank_functions, per_rank_inputs
        )
    ]
    mock_runtimes = [
        [0.0 for _ in range(num_warmup + num_repetitions)]
        for _ in range(_mock_world_size)
    ]
    return (per_rank_outputs, mock_runtimes)


def run_multiprocesses(
    ctx,
    per_rank_functions: Tuple[Function],
    per_rank_inputs: Tuple[Any],
    num_repetitions=1,
    num_warmup=0,
):
    assert len(per_rank_functions) == len(per_rank_inputs)
    args = [
        (r, f, x) for (r, (f, x)) in enumerate(zip(per_rank_functions, per_rank_inputs))
    ]

    per_rank_runner = partial(run_process, ctx, num_warmup, num_repetitions)
    mp = torch.multiprocessing.get_context("spawn")
    with mp.Pool(ctx.world_size) as p:
        outputs = p.starmap(per_rank_runner, args)

    per_rank_outputs, runtimes = zip(*outputs)
    return per_rank_outputs, runtimes


def run_pytorch(
    fn,
    inputs,
    use_gpu=False,
    num_repetitions=1,
    num_warmup=0,
    debug_mock=False,
):
    """Projects `fn` onto its devices and runs the resulting per-rank functions
    on `inputs` using the PyTorch backend.
    """
    # print(*(x.shape for x in inputs))
    # cpprint(fn)

    device_to_fns, groups = project(fn, tuple(v.type for v in fn.inputs))

    # Map between DistIR devices and pytorch ranks:
    device_to_rank = {}
    world_size = 0
    per_rank_fns = []
    for d in device_to_fns:
        device_to_rank[d] = world_size
        per_rank_fns.append(device_to_fns[d])
        world_size += 1

    ctx = DistributedContext(
        world_size=world_size,
        use_gpu=use_gpu,
        groups={},
        groups_list=list(groups),
        device_to_rank=device_to_rank,
    )

    per_rank_inputs = [[] for _ in range(world_size)]
    for v, a in zip(fn.inputs, inputs):
        per_rank_inputs[device_to_rank[v.type.device]].append(a)

    # for xs, per_rank_fn in zip(per_rank_inputs, per_rank_fns):
    #     print(*(x.shape for x in xs))
    #     cpprint(per_rank_fn)

    if debug_mock:
        return run_mock_multiprocess(per_rank_fns, per_rank_inputs)
    else:
        return run_multiprocesses(
            ctx,
            per_rank_fns,
            per_rank_inputs,
            num_repetitions=num_repetitions,
            num_warmup=num_warmup,
        )
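
A minimal usage sketch of the new run_pytorch entry point (not part of the diff). It assumes `fn` is a DistIR Function whose input values carry types with a shape and a device, as required by the projection step; `build_example_function` is a hypothetical placeholder for however the caller constructs such a Function.

import torch
from dist_ir.backend.torch import run_pytorch

fn = build_example_function()  # hypothetical: any projectable DistIR Function
inputs = [torch.randn(*v.type.shape) for v in fn.inputs]

# debug_mock=True runs the mock communication ops sequentially in one process;
# with debug_mock=False each projected per-rank function runs in its own
# spawned process over gloo (CPU) or nccl (use_gpu=True).
per_rank_outputs, runtimes = run_pytorch(
    fn,
    inputs,
    use_gpu=False,
    num_warmup=1,        # warmup iterations excluded from the reported runtimes
    num_repetitions=5,   # timed iterations per rank
)
print(runtimes)  # one list of per-iteration runtimes per rank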