# train.py (forked from hpcaitech/PaLM-colossalai)

import contextlib
import os

import colossalai
import torch
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.trainer import Trainer
from colossalai.utils import MultiTimer, get_current_device
from colossalai.zero.init_ctx import ZeroInitContext

from data import build_data
from model import build_model, build_loss
from utils import calc_model_size, AutoregressiveWrapper


def train_palm():
    disable_existing_loggers()

    parser = colossalai.get_default_parser()
    parser.add_argument("--from_torch", default=False, action="store_true")
    args = parser.parse_args()

    if args.from_torch:
        # torchrun / torch.distributed.launch already provides the distributed env vars
        colossalai.launch_from_torch(config=args.config, seed=42)
    else:
        # standard launch with explicit rank / world size / host / port
        colossalai.launch(
            config=args.config,
            rank=args.rank,
            world_size=args.world_size,
            local_rank=args.local_rank,
            host=args.host,
            port=args.port,
            seed=42,
        )

    # Initialize parameters in a ZeRO sharded context if the config has a `zero` section.
    use_zero = hasattr(gpc.config, "zero")
    ctx = contextlib.nullcontext()
    if use_zero:
        ctx = ZeroInitContext(
            target_device=torch.cuda.current_device(),
            shard_strategy=gpc.config.zero.model_config.shard_strategy,
            shard_param=True,
        )

    logger = get_dist_logger()
    if hasattr(gpc.config, "LOG_PATH"):
        log_path = gpc.config.LOG_PATH
        logger.log_to_file(log_path)

    assert hasattr(gpc.config, "BATCH_SIZE"), "Please provide BATCH_SIZE in your configuration"
    assert hasattr(gpc.config, "SEQ_LENGTH"), "Please provide SEQ_LENGTH in your configuration"

    # Build the model (sharded at init time when ZeRO is enabled) and wrap it
    # so it computes the autoregressive language-modeling loss.
    with ctx:
        model = build_model()
        model = AutoregressiveWrapper(model)

    numel, _ = calc_model_size(model)
    if numel < 1e9:
        msg = f"{numel / 1e6:.3f} M"
    else:
        msg = f"{numel / 1e9:.3f} B"
    model_mem = torch.cuda.max_memory_allocated(get_current_device()) / 1024**3
    logger.info("Model is built.", ranks=[0])
    logger.info(f"Parameter size = {msg} | Model memory = {model_mem:.3f} GB.", ranks=[0])

    criterion = build_loss()
    logger.info("Loss is built.", ranks=[0])

    # Use the optimizer specified in the config if present, otherwise fall back to AdamW.
    if hasattr(gpc.config, "optimizer"):
        optimizer = gpc.config.optimizer.pop("type")(model.parameters(), **gpc.config.optimizer)
    else:
        optimizer = torch.optim.AdamW(model.parameters(), lr=0.01, weight_decay=1e-2)
    logger.info("Optimizer is built.", ranks=[0])

    # Dataset and tokenizer locations come from the DATA and TOKENIZER environment variables.
    train_dataloader, test_dataloader = build_data(
        dataset_path=os.environ["DATA"],
        tokenizer_path=os.environ["TOKENIZER"],
        seq_len=gpc.config.SEQ_LENGTH,
        batch_size=gpc.config.BATCH_SIZE,
    )

    engine, train_dataloader, test_dataloader, _ = colossalai.initialize(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
    )

    # Map a tokenized batch to the (data, labels) pair expected by the engine schedule.
    def batch_data_process_func(batch_data):
        data = batch_data["input_ids"]
        labels = batch_data["labels"]
        return data, labels

    engine.schedule.batch_data_process_func = batch_data_process_func

    timer = MultiTimer()
    trainer = Trainer(engine=engine, logger=logger, timer=timer)
    hook_list = []

    logger.info("Training start.", ranks=[0])
    trainer.fit(
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        epochs=gpc.config.NUM_EPOCHS,
        max_steps=10,
        hooks=hook_list,
        return_output_label=False,
        display_progress=True,
    )
    logger.info("Training complete.", ranks=[0])


if __name__ == "__main__":
    train_palm()
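
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). The script expects a
# ColossalAI config module passed via --config that defines at least
# BATCH_SIZE, SEQ_LENGTH and NUM_EPOCHS, plus DATA and TOKENIZER environment
# variables pointing at the dataset and tokenizer. The file name
# `palm_config.py` and the values below are illustrative placeholders only.
#
#   # palm_config.py
#   import torch
#
#   BATCH_SIZE = 8
#   SEQ_LENGTH = 1024
#   NUM_EPOCHS = 1
#
#   # Optional: picked up by the hasattr(gpc.config, "optimizer") branch above;
#   # the "type" entry is popped and called with the remaining kwargs.
#   optimizer = dict(type=torch.optim.AdamW, lr=1e-3, weight_decay=1e-2)
#
# Example launch with torchrun, which provides the environment variables that
# colossalai.launch_from_torch reads:
#
#   DATA=/path/to/dataset TOKENIZER=/path/to/tokenizer \
#   torchrun --nproc_per_node=8 train.py --from_torch --config palm_config.py
# ---------------------------------------------------------------------------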