<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"
"http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html xmlns="http://www.w3.org/1999/xhtml">
<head>
<meta http-equiv="Content-Type" content="text/html; charset=utf-8" />
<title>The Varuna class — Varuna documentation</title>
<link rel="stylesheet" href="_static/alabaster.css" type="text/css" />
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
<script type="text/javascript">
var DOCUMENTATION_OPTIONS = {
URL_ROOT: './',
VERSION: '',
COLLAPSE_INDEX: false,
FILE_SUFFIX: '.html',
HAS_SOURCE: true,
SOURCELINK_SUFFIX: '.txt'
};
</script>
<script type="text/javascript" src="_static/jquery.js"></script>
<script type="text/javascript" src="_static/underscore.js"></script>
<script type="text/javascript" src="_static/doctools.js"></script>
<link rel="index" title="Index" href="genindex.html" />
<link rel="search" title="Search" href="search.html" />
<link rel="next" title="Profiling for Varuna" href="profiler.html" />
<link rel="prev" title="CutPoints" href="cutpoint.html" />
<link rel="stylesheet" href="_static/custom.css" type="text/css" />
<meta name="viewport" content="width=device-width, initial-scale=0.9, maximum-scale=0.9" />
</head>
<body>
<div class="document">
<div class="documentwrapper">
<div class="bodywrapper">
<div class="body" role="main">
<div class="section" id="the-varuna-class">
<h1>The Varuna class<a class="headerlink" href="#the-varuna-class" title="Permalink to this headline">¶</a></h1>
<p>The torch.nn.Module object for your DNN model should be wrapped in a <a class="reference internal" href="#varuna.Varuna" title="varuna.Varuna"><code class="xref py py-class docutils literal"><span class="pre">Varuna</span></code></a> instance for training.
This class extends torch.nn.Module and handles distributed pipeline &amp; data parallelism, mixed precision and
shared parameter weights internally.</p>
<p>Wrapping in <a class="reference internal" href="#varuna.Varuna" title="varuna.Varuna"><code class="xref py py-class docutils literal"><span class="pre">Varuna</span></code></a> partitions the model into pipeline stages across the distributed job.
For this, it uses stage allocation information that is passed by <code class="docutils literal"><span class="pre">varuna.launcher</span></code> to all
worker processes. The launcher uses a string argument <code class="docutils literal"><span class="pre">stage_to_rank_map</span></code> which must be parsed
and used for <a class="reference internal" href="#varuna.Varuna" title="varuna.Varuna"><code class="xref py py-class docutils literal"><span class="pre">Varuna</span></code></a> initialisation. (see <a class="reference internal" href="launching.html"><span class="doc">Launching Varuna</span></a>)</p>
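<p>As a rough illustration of the parsing step, the sketch below assumes the launcher passes the map as a string such as <code class="docutils literal"><span class="pre">"0,1;2,3;"</span></code>, with stages separated by semicolons and the ranks of each stage separated by commas. Both this format and the helper name <code class="docutils literal"><span class="pre">parse_stage_to_rank_map</span></code> are assumptions made for illustration; check the launcher's actual output before relying on them.</p>

```python
def parse_stage_to_rank_map(encoded):
    """Turn a string like "0,1;2,3;" into {0: [0, 1], 1: [2, 3]}.

    The string format is an assumption made for this sketch.
    """
    stage_to_rank_map = {}
    # Drop empty pieces so that a trailing ";" is harmless.
    stage_strings = [s for s in encoded.split(";") if s.strip()]
    for stage, ranks in enumerate(stage_strings):
        stage_to_rank_map[stage] = [int(r) for r in ranks.split(",") if r.strip()]
    return stage_to_rank_map


example = parse_stage_to_rank_map("0,1;2,3;")
print(example)
```

<p>Under these assumptions, the resulting dictionary maps each pipeline stage index to the list of global ranks that hold a replica of that stage.</p>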
<p>For profiling and automatic partitioning, Varuna needs sample inputs.
To supply them, pass a <code class="docutils literal"><span class="pre">get_batch_fn</span></code> during initialisation. It takes a batch size and returns a sample input batch
of that size as a dictionary of keyword arguments, in the same form as the inputs to the <code class="docutils literal"><span class="pre">step</span></code> function.
Varuna uses these batches to profile the model’s computation graph.</p>
<p>The model passed to <a class="reference internal" href="#varuna.Varuna" title="varuna.Varuna"><code class="xref py py-class docutils literal"><span class="pre">Varuna</span></code></a> should be on CPU. Once profiling and partitioning are done,
the model is moved to the assigned GPU, so there is no need to call <code class="docutils literal"><span class="pre">model.cuda()</span></code> anywhere.</p>
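<p>Create the optimizer after wrapping the model in <a class="reference internal" href="#varuna.Varuna" title="varuna.Varuna"><code class="xref py py-class docutils literal"><span class="pre">Varuna</span></code></a>, since the optimizer takes the model parameters as input.
The optimizer must then be registered with Varuna using <a class="reference internal" href="#varuna.Varuna.set_optimizer" title="varuna.Varuna.set_optimizer"><code class="xref py py-func docutils literal"><span class="pre">set_optimizer()</span></code></a>.</p>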
<p>Example:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">model</span> <span class="o">=</span> <span class="n">MyModel</span><span class="p">()</span> <span class="c1"># full model on CPU</span>
<span class="k">def</span> <span class="nf">get_batch_fn</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
    <span class="n">batch</span> <span class="o">=</span> <span class="n">dataset</span><span class="p">[:</span><span class="n">size</span><span class="p">]</span>
    <span class="k">if</span> <span class="n">device</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
        <span class="n">batch</span> <span class="o">=</span> <span class="p">[</span><span class="n">t</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">batch</span><span class="p">]</span>
    <span class="n">inputs</span><span class="p">,</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">batch</span>
    <span class="k">return</span> <span class="p">{</span><span class="s1">'inputs'</span><span class="p">:</span> <span class="n">inputs</span><span class="p">,</span> <span class="s1">'mask'</span><span class="p">:</span> <span class="n">mask</span><span class="p">,</span> <span class="s1">'extra_norm'</span><span class="p">:</span> <span class="bp">True</span> <span class="p">}</span>
<span class="c1"># parameter sharing across the model, marked as pairs of param_names</span>
<span class="n">shared_weights</span> <span class="o">=</span> <span class="p">[(</span><span class="s2">"language_model.embedding.word_embeddings.weight"</span><span class="p">,</span><span class="s2">"lm_head_weight"</span><span class="p">)]</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">Varuna</span><span class="p">(</span> <span class="n">model</span><span class="p">,</span> <span class="n">args</span><span class="o">.</span><span class="n">stage_to_rank_map</span><span class="p">,</span> <span class="n">get_batch_fn</span><span class="p">,</span> <span class="n">global_batch_size</span><span class="p">,</span>
                <span class="n">args</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">,</span> <span class="n">args</span><span class="o">.</span><span class="n">fp16</span><span class="p">,</span> <span class="n">local_rank</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">local_rank</span><span class="p">,</span>
                <span class="n">device</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">local_rank</span><span class="p">,</span> <span class="n">shared_weights</span><span class="o">=</span><span class="n">shared_weights</span><span class="p">)</span>
<span class="c1"># now model is a subset of the original model, moved to the GPU on each process</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">get_optimizer</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>
<span class="n">model</span><span class="o">.</span><span class="n">set_optimizer</span><span class="p">(</span><span class="n">optimizer</span><span class="p">)</span>
</pre></div>
</div>
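<p>Once the model is wrapped and the optimizer registered, a training loop might look roughly like the sketch below. It runs without GPUs by using two stand-in classes, <code class="docutils literal"><span class="pre">StandInModel</span></code> and <code class="docutils literal"><span class="pre">StandInOptimizer</span></code>, which are hypothetical and only mimic the documented return value of <code class="docutils literal"><span class="pre">step()</span></code>. Applying the update with <code class="docutils literal"><span class="pre">optimizer.step()</span></code> only when no overflow is reported is an assumption of this sketch, not a confirmed requirement.</p>

```python
class StandInModel:
    """Stand-in for a wrapped model so the sketch runs anywhere.

    It only mimics the documented return value of step():
    a tuple (average_loss, overflow).
    """

    def __init__(self, results):
        self._results = iter(results)

    def step(self, inputs, clip_grad_max_norm=None):
        return next(self._results)


class StandInOptimizer:
    """Stand-in optimizer that just counts how many updates were applied."""

    def __init__(self):
        self.updates = 0

    def step(self):
        self.updates += 1


def train(model, optimizer, batches):
    """Call step() once per batch and skip the update on fp16 overflow."""
    losses = []
    for inputs in batches:
        loss, overflow = model.step(inputs)
        if not overflow:
            optimizer.step()  # assumption: the caller applies the update
            losses.append(loss)
    return losses


model = StandInModel([(2.0, False), (float("inf"), True), (1.5, False)])
optimizer = StandInOptimizer()
losses = train(model, optimizer, batches=[{}, {}, {}])
print(losses, optimizer.updates)
```

<p>In a real job every worker would run the same loop, since <code class="docutils literal"><span class="pre">step()</span></code> must be called by all distributed workers.</p>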
<dl class="class">
<dt id="varuna.Varuna">
<em class="property">class </em><code class="descclassname">varuna.</code><code class="descname">Varuna</code><span class="sig-paren">(</span><em>model</em>, <em>stage_to_rank_map</em>, <em>get_batch_fn</em>, <em>batch_size</em>, <em>chunk_size</em>, <em>fp16=False</em>, <em>local_rank=-1</em>, <em>device=-1</em>, <em>shared_weights=None</em>, <em>from_cache=True</em><span class="sig-paren">)</span><a class="headerlink" href="#varuna.Varuna" title="Permalink to this definition">¶</a></dt>
<dd><p>Module to implement varuna training. The model must be wrapped in an instance
of <code class="docutils literal"><span class="pre">Varuna</span></code> before training. This should be done before optimizer creation and the
<code class="xref py py-attr docutils literal"><span class="pre">model</span></code> passed should be on CPU.</p>
<p>Creating a <code class="docutils literal"><span class="pre">Varuna</span></code> instance profiles the model briefly using sample inputs from <code class="xref py py-attr docutils literal"><span class="pre">get_batch_fn</span></code>
and partitions it according to the distributed rank and launcher arguments.
The partitioned model is then moved to the allocated cuda device. The profiling
information is cached and can be re-used on resuming, unless <code class="xref py py-attr docutils literal"><span class="pre">from_cache</span></code> is False.
The <code class="docutils literal"><span class="pre">Varuna</span></code> module performs mixed precision training internally if enabled through the
<code class="xref py py-attr docutils literal"><span class="pre">fp16</span></code> arg, no external handling is required.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
<col class="field-body" />
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple">
<li><strong>model</strong> (<em>torch.nn.Module</em>) – The model to initialize for training.</li>
<li><strong>stage_to_rank_map</strong> (<em>dict</em>) – Placement of pipeline stages in the distributed job, encoded as a string.
Passed by <code class="docutils literal"><span class="pre">varuna.launcher</span></code> to each worker as an argument.</li>
<li><strong>get_batch_fn</strong> (<em>function</em><em>(</em><em>size: int</em><em>, </em><em>device: torch.device</em><em> or </em><em>None</em><em>)</em>) – Function to get sample input batches of a given size, as dictionaries.
These are used to profile the model structure as <code class="docutils literal"><span class="pre">model(**get_batch_fn(k,</span> <span class="pre">device='cpu'))</span></code>.</li>
<li><strong>batch_size</strong> (<em>int</em>) – Global batch size for the distributed training job.</li>
<li><strong>chunk_size</strong> (<em>int</em>) – The micro-batch size to be used for pipeline parallelism.</li>
<li><strong>fp16</strong> (<em>bool</em>) – whether to enable mixed precision training.</li>
<li><strong>local_rank</strong> (<em>int</em>) – The local rank as passed by <code class="docutils literal"><span class="pre">varuna.launcher</span></code>. If not given,
defaults to the global rank.</li>
<li><strong>device</strong> (<em>int</em>) – index of the cuda device to use. Recommended to be the same as local_rank,
which is the default if not specified.</li>
<li><strong>shared_weights</strong> (<em>list</em><em> or </em><em>None</em>) – A list of tuples, where each tuple is a pair of weight names (strings),
such that the two weights are shared in the model (see weight sharing)</li>
<li><strong>from_cache</strong> (<em>bool</em>) – Whether to use cached profiling information if available.</li>
</ul>
</td>
</tr>
</tbody>
</table>
<div class="admonition note">
<p class="first admonition-title">Note</p>
<p class="last">Optimizer initialisation should be done after <code class="docutils literal"><span class="pre">Varuna</span></code> initialisation, so that the <code class="docutils literal"><span class="pre">param_groups</span></code>
for the optimizer only contain parameters from the partitioned model. This is important both for memory
usage and correctness of fp16 training. Once <code class="docutils literal"><span class="pre">Varuna</span></code> and the optimizer are initialised, <a class="reference internal" href="#varuna.Varuna.set_optimizer" title="varuna.Varuna.set_optimizer"><code class="xref py py-func docutils literal"><span class="pre">set_optimizer()</span></code></a>
should be called to connect the two.</p>
</div>
<dl class="method">
<dt id="varuna.Varuna.set_optimizer">
<code class="descname">set_optimizer</code><span class="sig-paren">(</span><em>optimizer</em>, <em>loss_scale='dynamic'</em>, <em>init_loss_scale=1048576</em>, <em>min_loss_scale=1.0</em><span class="sig-paren">)</span><a class="headerlink" href="#varuna.Varuna.set_optimizer" title="Permalink to this definition">¶</a></dt>
<dd><p>Configure the optimizer for training. If <code class="docutils literal"><span class="pre">fp16</span></code> is enabled, this function
initializes the mixed precision state in apex.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
<col class="field-body" />
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple">
<li><strong>optimizer</strong> (<em>torch.optim.Optimizer</em>) – the optimizer for training.</li>
<li><strong>loss_scale</strong> (<em>float</em><em> or </em><em>"dynamic"</em><em>, </em><em>optional</em>) – A floating point number for a static loss scale
or the string “dynamic” for dynamic loss scaling.</li>
<li><strong>init_loss_scale</strong> (<em>float</em><em>, </em><em>optional</em>) – Initial loss scale (for dynamic scaling)</li>
<li><strong>min_loss_scale</strong> (<em>float</em><em>, </em><em>optional</em>) – minimum loss scale (for dynamic scaling)</li>
</ul>
</td>
</tr>
</tbody>
</table>
</dd></dl>
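<p>To show how the three loss-scale arguments relate, here is a minimal sketch of a generic dynamic loss scaling policy: the scale starts at <code class="docutils literal"><span class="pre">init_loss_scale</span></code>, is halved on overflow but never drops below <code class="docutils literal"><span class="pre">min_loss_scale</span></code>, and is doubled after a run of overflow-free steps. The class name <code class="docutils literal"><span class="pre">DynamicLossScale</span></code>, the factor of 2 and the <code class="docutils literal"><span class="pre">growth_interval</span></code> parameter are assumptions for illustration; the policy actually applied by apex may differ in its details.</p>

```python
class DynamicLossScale:
    """Generic dynamic loss scaling: shrink on overflow, grow after a clean run."""

    def __init__(self, init_loss_scale=2.0 ** 20, min_loss_scale=1.0,
                 growth_interval=2000):
        self.scale = float(init_loss_scale)
        self.min_loss_scale = float(min_loss_scale)
        self.growth_interval = growth_interval  # hypothetical parameter
        self._clean_steps = 0

    def update(self, overflow):
        """Adjust the scale after one training step and return the new value."""
        if overflow:
            # Halve the scale, but never go below the configured floor.
            self.scale = max(self.scale / 2.0, self.min_loss_scale)
            self._clean_steps = 0
        else:
            self._clean_steps += 1
            if self._clean_steps == self.growth_interval:
                # Enough overflow-free steps in a row: try a larger scale.
                self.scale *= 2.0
                self._clean_steps = 0
        return self.scale


scaler = DynamicLossScale(init_loss_scale=1048576, min_loss_scale=1.0,
                          growth_interval=2)
scaler.update(overflow=True)   # scale is halved
scaler.update(overflow=False)
scaler.update(overflow=False)  # two clean steps: scale is doubled again
print(scaler.scale)
```

<p>The default <code class="docutils literal"><span class="pre">init_loss_scale</span></code> of 1048576 is 2<sup>20</sup>, so repeated halving and doubling keeps the scale at a power of two.</p>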
<dl class="method">
<dt id="varuna.Varuna.step">
<code class="descname">step</code><span class="sig-paren">(</span><em>inputs</em>, <em>clip_grad_max_norm=None</em><span class="sig-paren">)</span><a class="headerlink" href="#varuna.Varuna.step" title="Permalink to this definition">¶</a></dt>
<dd><p>Perform a single training step. Executes forward and backward passes for
the global batch. This function must be called by all distributed workers in the training loop.
After this function, the optimizer gradients are reduced across data parallel replicas and
overflow is checked for mixed precision training. Returns average loss and a boolean for overflow.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
<col class="field-body" />
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple">
<li><strong>inputs</strong> (<em>dict</em>) – The inputs to the model as a dictionary. These should be coordinated amongst workers -
the global batch is sharded across data parallel replicas, so each worker should have
<code class="docutils literal"><span class="pre">global_batch_size</span> <span class="pre">/</span> <span class="pre">data_parallel_depth</span></code> examples. All pipeline stages of the same
data parallel replica should receive the same inputs.</li>
<li><strong>clip_grad_max_norm</strong> (<em>float</em><em> or </em><em>None</em><em>, </em><em>optional</em>) – If given, the L2 gradient norm of the entire model
is clipped to this upper bound.</li>
</ul>
</td>
</tr>
<tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first">A tuple of the form (average_loss, overflow)</p>
</td>
</tr>
<tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last">tuple[float, bool]</p>
</td>
</tr>
</tbody>
</table>
</dd></dl>
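<p>The batch arithmetic described for <code class="docutils literal"><span class="pre">inputs</span></code> can be sketched as follows. The helper names <code class="docutils literal"><span class="pre">per_replica_batch</span></code> and <code class="docutils literal"><span class="pre">shard_for_replica</span></code> are hypothetical, and both the contiguous slicing and the treatment of a final, smaller micro-batch are assumptions; any consistent scheme works as long as all pipeline stages of one replica see the same inputs.</p>

```python
def per_replica_batch(global_batch_size, data_parallel_depth, chunk_size):
    """Return (examples per replica, number of micro-batches per step)."""
    if global_batch_size % data_parallel_depth != 0:
        raise ValueError("global batch must divide evenly across replicas")
    per_replica = global_batch_size // data_parallel_depth
    # Ceiling division: a final, smaller micro-batch covers any remainder.
    num_micro_batches = -(-per_replica // chunk_size)
    return per_replica, num_micro_batches


def shard_for_replica(examples, replica_index, data_parallel_depth):
    """Pick the contiguous slice of the global batch owned by one replica."""
    per_replica = len(examples) // data_parallel_depth
    start = replica_index * per_replica
    return examples[start:start + per_replica]


print(per_replica_batch(global_batch_size=512, data_parallel_depth=4, chunk_size=8))
print(shard_for_replica(list(range(8)), replica_index=1, data_parallel_depth=4))
```

<p>With a global batch of 512, four data parallel replicas and a chunk size of 8, each replica receives 128 examples per step and processes them as 16 micro-batches.</p>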
<dl class="method">
<dt id="varuna.Varuna.checkpoint">
<code class="descname">checkpoint</code><span class="sig-paren">(</span><em>global_store</em>, <em>step=None</em>, <em>tempdir=None</em>, <em>shard=False</em>, <em>on_demand=False</em><span class="sig-paren">)</span><a class="headerlink" href="#varuna.Varuna.checkpoint" title="Permalink to this definition">¶</a></dt>
<dd><p>Writes a varuna checkpoint with model parameters, optimizer state etc.
Each checkpoint is a directory, written under the given path.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
<col class="field-body" />
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple">
<li><strong>global_store</strong> (<em>str</em>) – path to a folder accessible by all nodes/ranks in the training job.
For example, path to a mounted blob storage. This is where the varuna checkpoint folder is written.</li>
<li><strong>step</strong> (<em>int</em><em> or </em><em>None</em><em>, </em><em>optional</em>) – iteration number for checkpoint. If None, it’ll be taken from varuna’s tracked progress.</li>
<li><strong>tempdir</strong> (<em>str</em><em>, </em><em>optional</em>) – path to a local directory to which to write checkpoints temporarily, and sync
with the global store in the background. Lowers checkpoint write time in the critical path.</li>
<li><strong>shard</strong> (<em>bool</em><em>, </em><em>optional</em>) – whether to shard checkpoint writes over data parallel workers as well. Speeds up checkpoint</li>
</ul>
</td>
</tr>
</tbody>
</table>
</dd></dl>
<dl class="method">
<dt id="varuna.Varuna.load_checkpoint">
<code class="descname">load_checkpoint</code><span class="sig-paren">(</span><em>global_store</em>, <em>iteration</em>, <em>check_complete=True</em><span class="sig-paren">)</span><a class="headerlink" href="#varuna.Varuna.load_checkpoint" title="Permalink to this definition">¶</a></dt>
<dd><p>Loads a varuna checkpoint from a shared directory. Each varuna checkpoint is a directory
named as “varuna_ckpt_&lt;iteration&gt;”. So the path under which all such checkpoints were written
should be specified.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
<col class="field-body" />
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first last simple">
<li><strong>global_store</strong> (<em>str</em>) – path under which varuna checkpoints were written.
Should be accessible by all workers.</li>
<li><strong>iteration</strong> (<em>int</em>) – Which iteration checkpoint to load.</li>
<li><strong>check_complete</strong> (<em>bool</em><em>, </em><em>optional</em>) – Check that the checkpoint is complete before loading it.
A checkpoint can be incomplete if the write was interrupted.</li>
</ul>
</td>
</tr>
</tbody>
</table>
</dd></dl>
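<p>Because each checkpoint is a directory named “varuna_ckpt_&lt;iteration&gt;”, a resuming job can scan the shared store to decide which <code class="docutils literal"><span class="pre">iteration</span></code> to pass to <code class="docutils literal"><span class="pre">load_checkpoint</span></code>. The sketch below is one way to do that; the helper <code class="docutils literal"><span class="pre">latest_checkpoint_iteration</span></code> is hypothetical, it assumes the directory names follow the pattern exactly, and it does not check whether a checkpoint is complete.</p>

```python
import os
import re
import tempfile

# Matches directory names of the form varuna_ckpt_N, where N is the iteration.
CKPT_PATTERN = re.compile(r"^varuna_ckpt_(\d+)$")


def latest_checkpoint_iteration(global_store):
    """Return the highest iteration with a varuna_ckpt_N directory, or None."""
    iterations = []
    for name in os.listdir(global_store):
        match = CKPT_PATTERN.match(name)
        if match and os.path.isdir(os.path.join(global_store, name)):
            iterations.append(int(match.group(1)))
    return max(iterations) if iterations else None


# Demonstration with empty placeholder directories in a temporary folder.
with tempfile.TemporaryDirectory() as store:
    for iteration in (100, 200, 1500):
        os.mkdir(os.path.join(store, f"varuna_ckpt_{iteration}"))
    latest = latest_checkpoint_iteration(store)
    print(latest)
```

<p>Comparing iterations as integers rather than as strings matters here, since a lexical sort would rank 200 above 1500.</p>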
<dl class="method">
<dt id="varuna.Varuna.evaluate">
<code class="descname">evaluate</code><span class="sig-paren">(</span><em>inputs</em>, <em>batch_size=None</em><span class="sig-paren">)</span><a class="headerlink" href="#varuna.Varuna.evaluate" title="Permalink to this definition">¶</a></dt>
<dd><p>Evaluate the model on the given inputs. This must be called on all workers
because it uses pipeline &amp; data parallelism. Inputs should be for the respective data parallel replica
and have <code class="docutils literal"><span class="pre">batch_size/data_parallel_depth</span></code> examples, similar to <a class="reference internal" href="#varuna.Varuna.step" title="varuna.Varuna.step"><code class="xref py py-func docutils literal"><span class="pre">step()</span></code></a>.
Returns loss averaged over all workers.</p>
<table class="docutils field-list" frame="void" rules="none">
<col class="field-name" />
<col class="field-body" />
<tbody valign="top">
<tr class="field-odd field"><th class="field-name">Parameters:</th><td class="field-body"><ul class="first simple">
<li><strong>inputs</strong> (<em>dict</em>) – Model inputs as dictionary. The number of examples
for these inputs should be the same as the batch_size defined for training.</li>
<li><strong>batch_size</strong> (<em>int</em><em>, </em><em>optional</em>) – Batch size for evaluation, if not given it’s the same as training batch size.</li>
</ul>
</td>
</tr>
<tr class="field-even field"><th class="field-name">Returns:</th><td class="field-body"><p class="first">average loss</p>
</td>
</tr>
<tr class="field-odd field"><th class="field-name">Return type:</th><td class="field-body"><p class="first last">float</p>
</td>
</tr>
</tbody>
</table>
</dd></dl>
</dd></dl>
</div>
</div>
</div>
</div>
<div class="sphinxsidebar" role="navigation" aria-label="main navigation">
<div class="sphinxsidebarwrapper"><div class="relations">
<h3>Related Topics</h3>
<ul>
<li><a href="index.html">Documentation overview</a><ul>
<li>Previous: <a href="cutpoint.html" title="previous chapter">CutPoints</a></li>
<li>Next: <a href="profiler.html" title="next chapter">Profiling for Varuna</a></li>
</ul></li>
</ul>
</div>
<div role="note" aria-label="source link">
<h3>This Page</h3>
<ul class="this-page-menu">
<li><a href="_sources/varuna.rst.txt"
rel="nofollow">Show Source</a></li>
</ul>
</div>
<div id="searchbox" style="display: none" role="search">
<h3>Quick search</h3>
<form class="search" action="search.html" method="get">
<div><input type="text" name="q" /></div>
<div><input type="submit" value="Go" /></div>
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
<script type="text/javascript">$('#searchbox').show(0);</script>
</div>
</div>
<div class="clearer"></div>
</div>
<div class="footer">
©2021, Nitika Saran.
|
Powered by <a href="http://sphinx-doc.org/">Sphinx 1.6.7</a>
&amp; <a href="https://github.com/bitprophet/alabaster">Alabaster 0.7.8</a>
|
<a href="_sources/varuna.rst.txt"
rel="nofollow">Page source</a>
</div>
</body>
</html>