Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BMTrain now supports 1F1B Pipeline schedule! #190

Open
wants to merge 51 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
9fb40bd
rename pipe env variable name
MayDomine Sep 1, 2023
c5ac256
pipe example test
MayDomine Sep 4, 2023
0f1a28a
Merge branch 'dev' of https://github.com/OpenBMB/BMTrain into dev
MayDomine Sep 4, 2023
22e9bc0
fix 1f1b stuck
MayDomine Sep 4, 2023
7fcbf0f
1F1B Pipeline compatible with zero
MayDomine Sep 7, 2023
cb81eb9
Merge branch 'dev' of https://github.com/OpenBMB/BMTrain into dev
MayDomine Sep 7, 2023
3c235a4
fix pipe embedding
MayDomine Sep 7, 2023
872285b
fix param init
MayDomine Sep 7, 2023
45e79d7
fix multi step validation and trying to fix embedding tied and datalo…
MayDomine Sep 11, 2023
9c05206
1f1b example
MayDomine Sep 14, 2023
9590493
add debug logger in bmt.init
MayDomine Sep 14, 2023
8dec50d
1f1b inspect and tied embedding
MayDomine Sep 18, 2023
3384c5f
1f1b stable version
MayDomine Sep 20, 2023
e35ff76
WIP: 1f1b training adaption
MayDomine Oct 12, 2023
3683007
1f1b training adaption and fix example
MayDomine Oct 18, 2023
05bb1b3
fix data loader for 1f1b
MayDomine Oct 20, 2023
5b7a18c
fix context for 1f1b
MayDomine Oct 20, 2023
2a1fda1
better example validation
MayDomine Oct 20, 2023
db60ce3
Optimizer for 1f1b adaption
MayDomine Oct 25, 2023
5da4c7f
fix logger
MayDomine Oct 26, 2023
ca2363f
fix logger level
MayDomine Oct 27, 2023
9e219b8
Merge remote-tracking branch 'bmb/dev' into pipe
MayDomine Oct 31, 2023
99eda09
fix scale
MayDomine Nov 3, 2023
586d0b8
fix recv async bug
MayDomine Nov 6, 2023
783adbf
fix lr_scheduler step and delete trash file
MayDomine Nov 6, 2023
5c2222b
avoid comm when no need
MayDomine Nov 8, 2023
72bfc33
fix comm bug
MayDomine Nov 8, 2023
b06c59f
add ckpt args in pipe blocklist
MayDomine Nov 8, 2023
1e687ea
scale loss in 1f1b
MayDomine Nov 9, 2023
7ed0f89
pipeline ckpt store and save
MayDomine Nov 10, 2023
f5933db
clone logits
Achazwl Nov 13, 2023
2c4606a
Merge branch 'dev' into pipe
MayDomine Feb 20, 2024
2ca387a
Merge branch 'dev' into pipe
MayDomine Mar 4, 2024
0bd5f2f
fix init
MayDomine Mar 5, 2024
c8e184f
refactor p2p ops
MayDomine Mar 18, 2024
851711c
formatting pipeline code
MayDomine Mar 18, 2024
5dad728
add grad scale for optim_manager
MayDomine Apr 15, 2024
151f679
fix a typo
MayDomine Apr 22, 2024
d6397e7
WIP: Pipeline example code refactor
MayDomine May 6, 2024
2de6bec
WIP: Pipeline example code refactor
MayDomine May 6, 2024
8b6f8db
WIP: Pipeline example code refactor
MayDomine May 6, 2024
07cc443
WIP: Pipeline example code refactor
MayDomine May 6, 2024
290c1e3
Pipeline example code refactor
MayDomine May 6, 2024
a742092
Pipeline example code refactor
MayDomine May 6, 2024
fd7ac11
support bmt.save/load save model partition instead of whole model
MayDomine May 8, 2024
3fcf4eb
Merge branch 'dev' into pipe
MayDomine May 8, 2024
af3458d
fix load model pipe
MayDomine May 8, 2024
0305c59
fix Block load param logic
MayDomine May 10, 2024
0e22e96
fix load partition
MayDomine May 11, 2024
88601be
fix OOM caused by PipeDreamBlockList.init_param_storage
MayDomine May 13, 2024
053eee5
add check_overflow even no loss scale enable
MayDomine Jul 8, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ on:
branches:
- 'dev'
- 'main'
push:
branches:
- 'dev'

jobs:
build-archive-wheel:

uses: OpenBMB/BMTrain/.github/workflows/build_whl.yml@main
secrets:
DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }}
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
secrets: inherit

publish:
needs: build-archive-wheel
Expand Down
6 changes: 3 additions & 3 deletions bmtrain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .utils import print_block, print_dict, print_rank, see_memory, load_nccl_pypi
from .utils import print_block, print_dict, print_rank, print_rank_pp, see_memory, load_nccl_pypi
try:
from . import nccl
except:
Expand All @@ -10,11 +10,11 @@
from .layer import DistributedModule
from .param_init import init_parameters, grouped_parameters
from .synchronize import synchronize, sum_loss, wait_loader, gather_result
from .block_layer import Block, TransformerBlockList
from .block_layer import Block, TransformerBlockList, PipeDreamBlockList
from .wrapper import BMTrainModelWrapper
from .pipe_layer import PipelineTransformerBlockList
from . import debug
from .store import save, load
from .store import save, load, clean

from . import loss
from . import distributed
Expand Down
27 changes: 27 additions & 0 deletions bmtrain/benchmark/all_reduce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from .. import nccl
from .shape import SHAPES
from ..global_var import config
from ..utils import round_up, print_rank
from .utils import format_size
import torch

def all_reduce():
    """Benchmark NCCL all-reduce bandwidth for every shape in SHAPES.

    For each shape, times a half-precision allReduce on the current CUDA
    stream and prints the transfer size, elapsed time, and effective
    bandwidth.
    """
    stream = torch.cuda.current_stream()
    for shape in SHAPES:
        global_size = round_up(shape, config['world_size'] * 2)

        # Uninitialized buffers are fine here: only transfer speed matters.
        src = torch.empty(global_size // 2, dtype=torch.half, device="cuda")
        dst = torch.empty(global_size // 2, dtype=torch.half, device="cuda")

        begin_evt = torch.cuda.Event(enable_timing=True)
        finish_evt = torch.cuda.Event(enable_timing=True)

        stream.record_event(begin_evt)
        nccl.allReduce(src.storage(), dst.storage(), "sum", config['comm'])
        stream.record_event(finish_evt)
        stream.synchronize()
        elapsed_ms = begin_evt.elapsed_time(finish_evt)

        # global_size half-precision elements => global_size * 2 bytes moved.
        bw = global_size / 1024 / 1024 / 1024 * 1000 / elapsed_ms * 2
        print_rank("All reduce:\tsize {}\ttime: {:4.3f}\tbw: {:2.6f} GB/s".format(format_size(global_size), elapsed_ms, bw))

232 changes: 196 additions & 36 deletions bmtrain/block_layer.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion bmtrain/distributed/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .ops import all_gather, all_reduce, broadcast, recv_activations, send_activations, reduce_scatter
from .ops import all_gather, all_reduce, broadcast, recv_tensor, send_tensor, groupcall, send_object, recv_object, reduce_scatter
12 changes: 12 additions & 0 deletions bmtrain/distributed/dtype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import torch
# Wire format for tensor metadata: a dtype is transmitted as its index in
# this list (see send_meta/recv_meta), so the ordering is part of the
# on-the-wire protocol shared by all ranks — never reorder or remove
# entries, only append new ones at the end.
DTYPE_LIST = [
    torch.float64,
    torch.float32,
    torch.float16,
    torch.int64,
    torch.int32,
    torch.int16,
    torch.int8,
    torch.bfloat16,
    torch.bool
]
47 changes: 6 additions & 41 deletions bmtrain/distributed/ops.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,13 @@
import torch
from ..global_var import config
from ..nccl import allGather as ncclAllGather, recv
import bmtrain as bmt
from ..global_var import config, rank
from ..nccl import allGather as ncclAllGather
from ..nccl import allReduce as ncclAllReduce
from ..nccl import broadcast as ncclBroadcast
from ..nccl import reduceScatter as ncclReduceScatter
from ..nccl import send as ncclSend
from ..nccl import recv as ncclRecv
from ..nccl import commCount,commRank,NCCLCommunicator
DTYPE_LIST = [
torch.float64,
torch.float32,
torch.float16,
torch.int64,
torch.int32,
torch.int16,
torch.int8,
torch.bfloat16,
torch.bool
]
def send_activations(hidden_state, next_rank, comm):
send_meta(hidden_state, next_rank, comm)
ncclSend(hidden_state.storage(), next_rank, comm)

def recv_activations(prev_rank, comm):
dtype, shape = recv_meta(prev_rank, comm)
hidden_state = torch.empty(shape, dtype=dtype, device="cuda")
ncclRecv(hidden_state.storage(), prev_rank, comm)
return hidden_state

def send_meta(x, next_rank, comm):
meta_data = torch.tensor(data=[0]*50, device="cuda", dtype=torch.int)
meta_data[0] = len(x.size())
meta_data[1] = DTYPE_LIST.index(x.dtype)
meta_data[2:len(x.size())+2] = torch.tensor(x.size(), device="cuda", dtype=torch.int)
meta_data = meta_data.contiguous()
ncclSend(meta_data.storage(), next_rank, comm)

def recv_meta(prev_rank, comm):
meta_data = torch.tensor(data=[0]*50, device="cuda", dtype=torch.int)
ncclRecv(meta_data.storage(), prev_rank, comm)
n_dims = meta_data[0].item()
dtype = DTYPE_LIST[meta_data[1].item()]
shape = meta_data[2:n_dims+2].tolist()
return dtype,shape
from ..nccl import commCount, commRank, NCCLCommunicator, groupStart, groupEnd
from .p2p_ops import *


class OpBroadcast(torch.autograd.Function):

Expand Down
159 changes: 159 additions & 0 deletions bmtrain/distributed/p2p_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import torch
from bmtrain import config
from ..nccl import reduceScatter as ncclReduceScatter
from ..nccl import send as ncclSend
from ..nccl import recv as ncclRecv
from ..nccl import groupStart,groupEnd
from .dtype import DTYPE_LIST
import pickle
import contextlib

_p2p_stream = {}
_p2p_events = {}

@contextlib.contextmanager
def groupcall():
    """Context manager that fuses the enclosed NCCL calls into one group.

    Wraps the body in groupStart()/groupEnd(). groupEnd() is executed in a
    ``finally`` so that an exception inside the body cannot leave the
    communicator stuck inside an open NCCL group (which would hang every
    subsequent collective).
    """
    groupStart()
    try:
        yield
    finally:
        groupEnd()
class handler:
    """Lightweight wait handle wrapping a CUDA event recorded on a p2p stream."""

    def __init__(self, event):
        # The event recorded at the end of the asynchronous p2p work.
        self.event = event

    def wait(self):
        """Make the caller's current stream wait until the p2p work completes."""
        torch.cuda.current_stream().wait_event(self.event)

def send_object(obj, peer_rank, comm):
    """Pickle *obj* and send it to *peer_rank*.

    Wire format: one int64 carrying the payload length, followed by the raw
    pickle bytes. Counterpart of recv_object.
    """
    payload = pickle.dumps(obj)
    size = torch.tensor([len(payload)], device="cuda", dtype=torch.long)
    ncclSend(size.storage(), peer_rank, comm)
    ncclSend(torch.ByteStorage.from_buffer(payload).cuda(), peer_rank, comm)

def recv_object(peer_rank, comm):
    """Receive a pickled Python object from *peer_rank* (counterpart of send_object).

    Wire format: one int64 with the payload length, then the raw pickle
    bytes. Returns the unpickled object.
    """
    data_length = torch.tensor([0], device="cuda", dtype=torch.long)
    ncclRecv(data_length.storage(), peer_rank, comm)
    # .item() synchronizes, so the received length is valid before allocating.
    data_bytes_stor = torch.cuda.ByteStorage(data_length.item())
    ncclRecv(data_bytes_stor, peer_rank, comm)
    tensor = torch.ByteTensor(data_bytes_stor.cpu())
    # NOTE(review): pickle.loads on peer-supplied bytes is only acceptable
    # because all ranks belong to the same trusted training job — never reuse
    # this path for untrusted input.
    data = pickle.loads(tensor.numpy().tobytes())
    return data

def record_stream_helper(tensor_list, stream):
    """Mark each tensor in *tensor_list* as in use on *stream*.

    Prevents the CUDA caching allocator from reclaiming the tensors' memory
    before work queued on *stream* has finished with them.
    """
    for tensor in tensor_list:
        tensor.record_stream(stream)

def send_tensors(tensor_list, peer_rank, comm):
    """Synchronous send of *tensor_list* to *peer_rank*.

    Dispatches the transfer on the per-peer p2p stream, then makes the
    current stream wait for it before returning.
    """
    # Named `h` to avoid shadowing the `handler` class.
    h = _send_tensors(tensor_list, peer_rank, comm)
    h.wait()

def isend_tensor(tensor_list, peer_rank, comm):
    """Asynchronous variant of send_tensors.

    Returns a handle; call handle.wait() to make the current stream wait for
    the transfer to finish.
    """
    return _send_tensors(tensor_list, peer_rank, comm)

def _send_tensors(tensor_list, peer_rank, comm):
    """Send a heterogeneous list to *peer_rank* on a dedicated p2p stream.

    Wire protocol: an int32 element count, then one int32 flag per entry
    (-1 = None, 0 = tensor, 1 = arbitrary picklable object), then each
    non-None entry via send_tensor / send_object.

    Returns a handler whose wait() makes the caller's stream wait for the
    transfer; the send itself runs asynchronously on a per-peer stream.
    """
    p2p_key = f"send {peer_rank}"
    # One stream/event pair per destination rank, created lazily and reused.
    if p2p_key not in _p2p_stream:
        _p2p_stream[p2p_key] = torch.cuda.Stream()
    if p2p_key not in _p2p_events:
        _p2p_events[p2p_key] = torch.cuda.Event()
    stream = _p2p_stream[p2p_key]
    event = _p2p_events[p2p_key]
    # Make the p2p stream wait for work already queued on the current stream
    # so the tensors being sent hold their final values.
    event.record(torch.cuda.current_stream())
    stream.wait_event(event)
    with torch.cuda.stream(stream):
        # len(tensor_list) directly — the original list comprehension
        # ([h for h in tensor_list]) built a throwaway copy just to count.
        length = torch.tensor(data=[len(tensor_list)], device="cuda", dtype=torch.int)
        flags = torch.tensor(data=[0] * len(tensor_list), device="cuda", dtype=torch.int)
        for i, item in enumerate(tensor_list):
            if item is None:
                flags[i] = -1
            elif torch.is_tensor(item):
                flags[i] = 0
            else:
                flags[i] = 1
        ncclSend(length.storage(), peer_rank, comm)
        ncclSend(flags.contiguous().storage(), peer_rank, comm)
        for i, item in enumerate(tensor_list):
            if flags[i] == 0:
                # Keep the tensor's memory alive until the p2p stream is done.
                item.record_stream(stream)
                send_tensor(item, peer_rank, comm)
            elif flags[i] == 1:
                send_object(item, peer_rank, comm)
    event.record(stream)
    return handler(event)

def recv_tensors(peer_rank, comm):
    """Blocking receive of a list sent via send_tensors from *peer_rank*.

    Returns the received list after the current stream has been made to wait
    for the transfer.
    """
    tensors, h = _recv_tensors(peer_rank, comm)
    h.wait()
    return tensors

def irecv_tensors(peer_rank, comm):
    """Non-blocking receive from *peer_rank*.

    Returns (tensor_list, handle); call handle.wait() before using the
    received tensors.
    """
    return _recv_tensors(peer_rank, comm)

def _recv_tensors(peer_rank, comm):
    """Receive a list sent by _send_tensors from *peer_rank* on a dedicated stream.

    Mirrors the _send_tensors wire protocol: int32 count, per-entry flags
    (-1 = None, 0 = tensor, 1 = pickled object), then the payloads.
    Returns (tensor_list, handler); the caller must wait() on the handler
    before using the tensors.
    """
    p2p_key = f"recv {peer_rank}"
    # Lazily create one stream/event pair per source rank, then reuse it.
    if p2p_key not in _p2p_stream:
        _p2p_stream[p2p_key] = torch.cuda.Stream()
    if p2p_key not in _p2p_events:
        _p2p_events[p2p_key] = torch.cuda.Event()
    stream = _p2p_stream[p2p_key]
    event = _p2p_events[p2p_key]
    # NOTE(review): unlike _send_tensors, the recv stream does not wait on the
    # current stream here — presumably safe because receives write into fresh
    # buffers, but worth confirming.
    with torch.cuda.stream(stream):
        length = torch.tensor(data=[0], device="cuda", dtype=torch.int)
        tensor_list = []
        ncclRecv(length.storage(), peer_rank, comm)
        # range(length) calls length.__index__(), which synchronizes so the
        # received count is valid before `flags` is allocated.
        flags = torch.tensor(data=[0 for _ in range(length)], device="cuda", dtype=torch.int)
        ncclRecv(flags.storage(), peer_rank, comm)
        for i in range(length[0].item()):
            flag = flags[i].item()
            if flag == -1:
                tensor_list.append(None)
            elif flag == 0:
                recv = recv_tensor(peer_rank, comm)
                tensor_list.append(recv)
            elif flag == 1:
                recv = recv_object(peer_rank, comm)
                tensor_list.append(recv)
    event.record(stream)
    # Guard received buffers against premature reuse on the caller's stream.
    # NOTE(review): flag == 1 entries are plain Python objects without a
    # record_stream method — this filter only excludes flag == -1, so a list
    # mixing tensors and objects would raise here. Presumably object-bearing
    # lists never reach this path; confirm against callers.
    record_stream_helper([tensor_list[i] for i in range(length[0].item()) if flags[i].item() != -1], torch.cuda.current_stream())
    return tensor_list, handler(event)

def send_tensor(hidden_state, peer_rank, comm):
    """Send one tensor to *peer_rank*, preceded by its dtype/shape metadata."""
    contiguous = hidden_state.contiguous()
    send_meta(contiguous, peer_rank, comm)
    ncclSend(contiguous.storage(), peer_rank, comm)

def send_tensor_inplace(hidden_state, peer_rank, comm):
    """Send a tensor without metadata; the receiver must already know its dtype/shape."""
    ncclSend(hidden_state.contiguous().storage(), peer_rank, comm)

def recv_tensor_inplace(hidden_state, peer_rank, comm):
    """Receive directly into *hidden_state*'s storage and return it.

    Note: if *hidden_state* is non-contiguous, the data lands in the
    contiguous copy that is returned, not in the original tensor.
    """
    buf = hidden_state.contiguous()
    ncclRecv(buf.storage(), peer_rank, comm)
    return buf

def recv_tensor(peer_rank, comm):
    """Receive one tensor from *peer_rank*: metadata first, then the payload.

    Allocates a fresh CUDA buffer of the announced dtype/shape and returns it.
    """
    dtype, shape = recv_meta(peer_rank, comm)
    result = torch.empty(shape, dtype=dtype, device="cuda")
    ncclRecv(result.storage(), peer_rank, comm)
    return result

def send_meta(x, peer_rank, comm):
    """Send dtype/shape metadata for tensor *x* to *peer_rank*.

    Encoded as a fixed 50-int32 buffer: [ndim, index of x.dtype in
    DTYPE_LIST, dims...]. The fixed size caps tensors at 48 dimensions;
    previously a larger tensor would silently write out of the slice and
    corrupt the metadata, so reject it explicitly.

    Raises:
        ValueError: if x has more than 48 dimensions.
    """
    ndim = len(x.size())
    if ndim > 48:
        raise ValueError(f"send_meta supports at most 48 dimensions, got {ndim}")
    meta_data = torch.tensor(data=[0] * 50, device="cuda", dtype=torch.int)
    meta_data[0] = ndim
    meta_data[1] = DTYPE_LIST.index(x.dtype)
    meta_data[2:ndim + 2] = torch.tensor(x.size(), device="cuda", dtype=torch.int)
    meta_data = meta_data.contiguous()
    ncclSend(meta_data.storage(), peer_rank, comm)

def recv_meta(peer_rank, comm):
    """Receive the metadata written by send_meta from *peer_rank*.

    Returns (dtype, shape) where shape is a plain Python list of ints.
    """
    buf = torch.tensor(data=[0] * 50, device="cuda", dtype=torch.int)
    ncclRecv(buf.storage(), peer_rank, comm)
    ndim = buf[0].item()
    dtype = DTYPE_LIST[buf[1].item()]
    shape = buf[2:ndim + 2].tolist()

    return dtype, shape
56 changes: 40 additions & 16 deletions bmtrain/hook_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,51 +4,75 @@

def zero_pre_forward(module, inputs):
enter = True
pipe = False
if module._mode == "PIPE":
enter = module._micro_idx == 0
pipe = True
if module._mode == "PIPE" or module._mode == "1F1B":
if not hasattr(module, "_micro_forward_idx") or module._micro_forward_idx == -1:
module._micro_forward_idx = 0
enter = True
else:
enter = False
module._micro_forward_idx += 1
if enter:
zero_level = module._zero_level
forward_flag = 1 if zero_level == 2 else 0
if zero_level == 2 and not module._need_release:
forward_flag = 2 # repeating forward in same layer
if module.all_param_no_grad: #only forward
forward_flag = 0
module._forward_block_ctx = ZeroContext(module, module._layer_dict, pipe=pipe)
module._forward_block_ctx.enter(forward_flag)
if module._mode == "1F1B":
module._block_ctx = ZeroContext(module, module._layer_dict)
module._block_ctx.enter(0, requires_grad=True)
else:
module._forward_block_ctx = ZeroContext(module, module._layer_dict)
module._forward_block_ctx.enter(forward_flag)

def zero_post_forward(module, inputs, outputs):
forward_flag = 1 if module._zero_level == 2 else 0
if module.all_param_no_grad:
forward_flag = 0
exit = True
if module._mode == "PIPE":
exit = module._micro_idx == config['micros'] - 1
if module._mode == "PIPE" or module._mode == "1F1B":
if module._micro_forward_idx == config["micros"] - 1:
module._micro_forward_idx = -1
if module._mode == "1F1B":
exit = False
else:
exit = True
else:
exit = False

if exit:
module._forward_block_ctx.exit(forward_flag)

def zero_pre_backward(module, grad_outputs):
backward_flag = 2 if module._zero_level == 2 else 0
if module._mode != "PIPE":
if module._mode != "PIPE" and module._mode != "1F1B":
module._backward_block_ctx = ZeroContext(module, module._layer_dict)
module._backward_block_ctx.enter(backward_flag, True)
module.release_next_module(backward_flag)
else:
if module._micro_idx == config['micros'] - 1:
module._backward_block_ctx = ZeroContext(module, module._layer_dict, pipe=True)
module._backward_block_ctx.enter(backward_flag, True)
if not hasattr(module, "_micro_backward_idx") or module._micro_backward_idx == -1:
if module._mode == "1F1B":
module._micro_backward_idx = 0
else:
module._micro_backward_idx = 0
module._backward_block_ctx = ZeroContext(module, module._layer_dict)
module._backward_block_ctx.enter(backward_flag,requires_grad=True)
else:
module._micro_backward_idx += 1

def zero_post_backward(module, grad_inputs, grad_outputs):
backward_flag = 2 if module._zero_level == 2 else 0
if module._mode != "PIPE":
if module._mode != "PIPE" and module._mode != "1F1B":
if module._is_first_layer:
module.release(backward_flag)
else:
if module._micro_idx == 0:
module.release(backward_flag)
module._micro_idx -= 1
if module._micro_backward_idx == config["micros"] - 1:
if module._mode == "1F1B":
module._block_ctx.exit(0, backward=True)
config['load_stream'].record_event(config['load_event'])
else:
module.release(backward_flag)
module._micro_backward_idx = -1

class OneStepNoGradFunc(torch.autograd.Function):
"""
Expand Down
Loading
Loading