forked from microsoft/varuna
-
Notifications
You must be signed in to change notification settings - Fork 0
/
apex.patch
135 lines (128 loc) · 6.13 KB
/
apex.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
diff --git a/apex/amp/_process_optimizer.py b/apex/amp/_process_optimizer.py
index 471289b..3e38de0 100644
--- a/apex/amp/_process_optimizer.py
+++ b/apex/amp/_process_optimizer.py
@@ -351,10 +351,10 @@ def _process_optimizer(optimizer, properties):
_master_params_to_model_params, optimizer)
old_step = optimizer.step
- def new_step(self, closure=None):
+ def new_step(self, global_grad_norm=-1, closure=None):
if closure is not None:
raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
- retval = old_step()
+ retval = old_step(global_grad_norm=global_grad_norm)
if not isinstance(self, FusedSGD):
self._master_params_to_model_params()
# Clear the master grads that wouldn't be zeroed by model.zero_grad()
diff --git a/apex/amp/handle.py b/apex/amp/handle.py
index 0be567c..5d844fd 100644
--- a/apex/amp/handle.py
+++ b/apex/amp/handle.py
@@ -19,7 +19,8 @@ def scale_loss(loss,
loss_id=0,
model=None,
delay_unscale=False,
- delay_overflow_check=False):
+ delay_overflow_check=False,
+ last_partition=True):
"""
On context manager entrance, creates ``scaled_loss = (loss.float())*current loss scale``.
``scaled_loss`` is yielded so that the user can call ``scaled_loss.backward()``::
@@ -110,7 +111,10 @@ def scale_loss(loss,
if not optimizer._amp_stash.params_have_scaled_gradients:
optimizer._prepare_amp_backward()
- yield (loss.float())*loss_scale
+ if last_partition:
+ yield (loss.float())*loss_scale
+ else:
+ yield loss.float()
if delay_unscale:
for optimizer in optimizers:
diff --git a/apex/amp/scaler.py b/apex/amp/scaler.py
index 99888bc..63f5457 100644
--- a/apex/amp/scaler.py
+++ b/apex/amp/scaler.py
@@ -205,6 +205,7 @@ class LossScaler(object):
self._loss_scale = max(self._min_loss_scale, self._loss_scale/2.)
else:
self._loss_scale = self._loss_scale/2.
+ print(torch.distributed.get_rank(), ': update_scale(): _has_overflow, dynamic. _loss_scale = ', self._loss_scale)
self._unskipped = 0
else:
should_skip = False
diff --git a/apex/optimizers/fused_lamb.py b/apex/optimizers/fused_lamb.py
index 854525d..c8bcd70 100644
--- a/apex/optimizers/fused_lamb.py
+++ b/apex/optimizers/fused_lamb.py
@@ -93,7 +93,7 @@ class FusedLAMB(torch.optim.Optimizer):
else:
super(FusedLAMB, self).zero_grad()
- def step(self, closure=None):
+ def step(self, global_grad_norm=-1, closure=None):
"""Performs a single optimization step.
Arguments:
@@ -104,36 +104,37 @@ class FusedLAMB(torch.optim.Optimizer):
if closure is not None:
loss = closure()
- # create separate grad lists for fp32 and fp16 params
- g_all_32, g_all_16 = [], []
- for group in self.param_groups:
- for p in group['params']:
- if p.grad is None:
- continue
- if p.dtype == torch.float32:
- g_all_32.append(p.grad.data)
- elif p.dtype == torch.float16:
- g_all_16.append(p.grad.data)
- else:
- raise RuntimeError('FusedLAMB only support fp16 and fp32.')
-
- device = self.param_groups[0]["params"][0].device
- g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device)
- # compute grad norm for two lists
- if len(g_all_32) > 0:
- g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
- self._dummy_overflow_buf,
- [g_all_32], False)[0]
- if len(g_all_16) > 0:
- g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
- self._dummy_overflow_buf,
- [g_all_16], False)[0]
-
- # blend two grad norms to get global grad norm
- global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,
+ if global_grad_norm == -1:
+ # create separate grad lists for fp32 and fp16 params
+ g_all_32, g_all_16 = [], []
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ if p.dtype == torch.float32:
+ g_all_32.append(p.grad.data)
+ elif p.dtype == torch.float16:
+ g_all_16.append(p.grad.data)
+ else:
+ raise RuntimeError('FusedLAMB only support fp16 and fp32.')
+
+ device = self.param_groups[0]["params"][0].device
+ g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device)
+ # compute grad norm for two lists
+ if len(g_all_32) > 0:
+ g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
- [[g_norm_32, g_norm_16]],
- False)[0]
+ [g_all_32], False)[0]
+ if len(g_all_16) > 0:
+ g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
+ self._dummy_overflow_buf,
+ [g_all_16], False)[0]
+
+ # blend two grad norms to get global grad norm
+ global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,
+ self._dummy_overflow_buf,
+ [[g_norm_32, g_norm_16]],
+ False)[0]
max_grad_norm = self.defaults['max_grad_norm']
for group in self.param_groups: