-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_with_engine.py
122 lines (102 loc) · 4.56 KB
/
train_with_engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import glob
import os
import colossalai
import torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.utils import is_using_pp
from titans.dataloader.imagenet import build_dali_imagenet
from colossalai.pipeline.pipelinable import PipelinableContext
from titans.model.vit.vit import _create_vit_model
from tqdm import tqdm
def _build_vit_model():
    """Create the ViT model described by the global config.

    When pipeline parallelism is in use (``is_using_pp()``), the model is
    constructed inside a ``PipelinableContext``, converted to a layer list and
    partitioned with the ``"uniform"`` policy for the local pipeline rank.
    Otherwise the model is created directly.

    Returns:
        The model (or pipeline partition) to train on this rank.
    """
    model_kwargs = dict(
        img_size=gpc.config.IMG_SIZE,
        patch_size=gpc.config.PATCH_SIZE,
        hidden_size=gpc.config.HIDDEN_SIZE,
        depth=gpc.config.DEPTH,
        num_heads=gpc.config.NUM_HEADS,
        mlp_ratio=gpc.config.MLP_RATIO,
        num_classes=gpc.config.NUM_CLASSES,
        init_method='jax',
        checkpoint=gpc.config.CHECKPOINT,
    )

    if not is_using_pp():
        return _create_vit_model(**model_kwargs)

    pipelinable = PipelinableContext()
    with pipelinable:
        model = _create_vit_model(**model_kwargs)
    pipelinable.to_layer_list()
    pipelinable.policy = "uniform"
    # NOTE(review): the leading ``1`` is presumably the number of chunks per
    # stage -- confirm against the PipelinableContext.partition signature.
    return pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))


def train_imagenet():
    """Train a ViT model on ImageNet using the Colossal-AI engine.

    The config file path is taken from the default Colossal-AI argument
    parser and the distributed environment is launched with
    ``colossalai.launch_from_torch``. All hyper-parameters (model shape, batch
    size, learning rate, number of epochs, ...) are read from ``gpc.config``.
    The dataset root is read from the ``DATA`` environment variable.

    Raises:
        KeyError: if the ``DATA`` environment variable is not set.
    """
    args = colossalai.get_default_parser().parse_args()

    # Launched through the torch distributed launcher. For a SLURM launch use
    # colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500)
    # instead.
    colossalai.launch_from_torch(config=args.config)
    logger = get_dist_logger()
    is_main_rank = gpc.get_global_rank() == 0

    # Only global rank 0 logs to file, and only when the config asks for it.
    if hasattr(gpc.config, 'LOG_PATH') and is_main_rank:
        log_path = gpc.config.LOG_PATH
        # makedirs(exist_ok=True) avoids the exists()/mkdir() race and also
        # creates missing parent directories.
        os.makedirs(log_path, exist_ok=True)
        logger.log_to_file(log_path)

    # create model
    model = _build_vit_model()

    # count number of parameters held by this rank
    total_numel = sum(p.numel() for p in model.parameters())
    if gpc.is_initialized(ParallelMode.PIPELINE):
        pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
    else:
        pipeline_stage = 0
    logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")

    # create dataloaders; the DATA environment variable must be set
    root = os.environ['DATA']
    train_dataloader, test_dataloader = build_dali_imagenet(root,
                                                            train_batch_size=gpc.config.BATCH_SIZE,
                                                            test_batch_size=gpc.config.BATCH_SIZE)

    # create loss function
    criterion = CrossEntropyLoss(label_smoothing=0.1)

    # create optimizer
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=gpc.config.LEARNING_RATE,
                                  weight_decay=gpc.config.WEIGHT_DECAY)

    # create lr scheduler; note that its horizon is expressed in epochs
    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
                                           total_steps=gpc.config.NUM_EPOCHS,
                                           warmup_steps=gpc.config.WARMUP_EPOCHS)

    # initialize
    engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
                                                                         optimizer=optimizer,
                                                                         criterion=criterion,
                                                                         train_dataloader=train_dataloader,
                                                                         test_dataloader=test_dataloader)
    logger.info("Engine is built", ranks=[0])

    for epoch in range(gpc.config.NUM_EPOCHS):
        # training
        engine.train()

        # Request a fresh iterator every epoch so that loaders whose iterator
        # is exhausted after one pass keep yielding batches in later epochs.
        data_iter = iter(train_dataloader)

        progress = range(len(train_dataloader))
        if is_main_rank:
            # Only rank 0 shows a progress bar to keep the console readable.
            progress = tqdm(progress, desc=f'Epoch {epoch} / {gpc.config.NUM_EPOCHS}')

        for _ in progress:
            engine.zero_grad()
            engine.execute_schedule(data_iter, return_output_label=False)
            engine.step()

        # The scheduler was configured with total_steps=NUM_EPOCHS and
        # warmup_steps=WARMUP_EPOCHS, so it advances once per epoch.
        lr_scheduler.step()
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train_imagenet()