<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"
"http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html xmlns="http://www.w3.org/1999/xhtml">
<head>
<meta http-equiv="Content-Type" content="text/html; charset=utf-8" />
<title>CutPoints — Varuna documentation</title>
<link rel="stylesheet" href="_static/alabaster.css" type="text/css" />
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
<script type="text/javascript">
var DOCUMENTATION_OPTIONS = {
URL_ROOT: './',
VERSION: '',
COLLAPSE_INDEX: false,
FILE_SUFFIX: '.html',
HAS_SOURCE: true,
SOURCELINK_SUFFIX: '.txt'
};
</script>
<script type="text/javascript" src="_static/jquery.js"></script>
<script type="text/javascript" src="_static/underscore.js"></script>
<script type="text/javascript" src="_static/doctools.js"></script>
<link rel="index" title="Index" href="genindex.html" />
<link rel="search" title="Search" href="search.html" />
<link rel="next" title="The Varuna class" href="varuna.html" />
<link rel="prev" title="Launching Varuna" href="launching.html" />
<link rel="stylesheet" href="_static/custom.css" type="text/css" />
<meta name="viewport" content="width=device-width, initial-scale=0.9, maximum-scale=0.9" />
</head>
<body>
<div class="document">
<div class="documentwrapper">
<div class="bodywrapper">
<div class="body" role="main">
<div class="section" id="cutpoints">
<h1>CutPoints<a class="headerlink" href="#cutpoints" title="Permalink to this headline">¶</a></h1>
<p>Varuna slices a DNN model into sequential stages for pipeline parallelism.
To enable this, annotate the model with Varuna <code class="xref py py-class docutils literal"><span class="pre">CutPoint</span></code> instances between
different operations or parts of the model’s computation.</p>
<p>A <code class="xref py py-class docutils literal"><span class="pre">CutPoint</span></code> in Varuna is an abstraction that marks a potential partition point in your model.
It is implemented as a <code class="xref py py-class docutils literal"><span class="pre">torch.nn.Module</span></code> instance, which is called on the activations at the potential
boundary point. For each <code class="xref py py-class docutils literal"><span class="pre">CutPoint</span></code>, Varuna can either ignore it or activate it as a partition boundary.
CutPoints can be marked anywhere in the model as follows:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="kn">from</span> <span class="nn">varuna</span> <span class="kn">import</span> <span class="n">CutPoint</span>

<span class="k">class</span> <span class="nc">SampleModel</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="o">...</span><span class="p">):</span>
        <span class="o">....</span>
        <span class="c1"># nn.ModuleList registers each CutPoint as a submodule of the model</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">cutpoints</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">CutPoint</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_cutpoints</span><span class="p">)])</span>
        <span class="o">....</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="o">...</span><span class="p">):</span>
        <span class="nb">input</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">some_operation</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
        <span class="nb">input</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cutpoints</span><span class="p">[</span><span class="mi">0</span><span class="p">](</span><span class="nb">input</span><span class="p">)</span>     <span class="c1"># marked as a potential stage boundary</span>
        <span class="nb">input</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">some_other_operation</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
        <span class="o">....</span>
        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">sub_modules</span><span class="p">):</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">sub_module_i</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="o">...</span><span class="p">)</span>
            <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cutpoints</span><span class="p">[</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">](</span><span class="n">x</span><span class="p">)</span>        <span class="c1"># each cutpoint instance should be used only once in a model</span>
        <span class="o">....</span>
</pre></div>
</div>
<p>Based on the number of desired pipeline stages, Varuna chooses a subset of the given cutpoints and
activates them as actual boundaries between stages. For example, if the user marks <cite>n</cite> cutpoints in total
and wants 4 pipeline stages, 3 cutpoints are activated as boundaries between the 4 stages and the remaining
<cite>n-3</cite> are treated as if they don’t exist.
<br/> With this partitioning, each worker in the distributed job runs a sub-section of the model code between
two activated <code class="xref py py-class docutils literal"><span class="pre">CutPoint</span></code> instances, or between one activated <code class="xref py py-class docutils literal"><span class="pre">CutPoint</span></code> and the
beginning/end of the model.</p>
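<p>The relationship between marked and activated cutpoints can be pictured with a small sketch.
The even-spacing rule and the helper name <code class="docutils literal"><span class="pre">choose_active_cutpoints</span></code> are assumptions made
for illustration only; the policy Varuna actually uses to pick cutpoints is not described here and may differ.</p>

```python
def choose_active_cutpoints(num_cutpoints, num_stages):
    """Return the indices of the cutpoints to activate for num_stages stages.

    Hypothetical even-spacing policy, for illustration only. It is not
    taken from Varuna's implementation.
    """
    # k stages need k - 1 boundaries, so at most num_cutpoints + 1 stages fit.
    if num_stages not in range(1, num_cutpoints + 2):
        raise ValueError(
            f"cannot form {num_stages} stages from {num_cutpoints} cutpoints"
        )
    # n cutpoints separate the model into n + 1 segments; give each stage
    # roughly the same number of segments.
    num_segments = num_cutpoints + 1
    return [
        (boundary * num_segments) // num_stages - 1
        for boundary in range(1, num_stages)
    ]


if __name__ == "__main__":
    # 7 marked cutpoints, 4 stages: 3 are activated, the other 4 are ignored.
    print(choose_active_cutpoints(7, 4))
```

<p>With 7 marked cutpoints and 4 stages, this sketch activates the cutpoints at indices 1, 3 and 5,
so each stage runs two of the eight model segments.</p>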
<p>For an activated <code class="xref py py-class docutils literal"><span class="pre">CutPoint</span></code>, the input to the cutpoint is an intermediate activation in the model
that needs to be passed between sequential stages.</p>
<div class="admonition note">
<p class="first admonition-title">Note</p>
<p class="last">The input to any <code class="xref py py-class docutils literal"><span class="pre">CutPoint</span></code> in the model’s execution should be a single <code class="xref py py-class docutils literal"><span class="pre">torch.Tensor</span></code>
of shape <cite>(b, d2, d3, …)</cite> where <cite>b</cite> is the number of examples in the input to the model.
<br/> This is important because Varuna uses micro-batches to parallelize computation and relies on this
format for communication between pipeline stages.</p>
</div>
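<p>To see why the leading batch dimension matters, here is a minimal sketch of splitting a batch into
micro-batches along its first dimension. It uses a plain Python sequence as a stand-in for a tensor of shape
<cite>(b, d2, d3, …)</cite>, and the helper and its <code class="docutils literal"><span class="pre">micro_batch_size</span></code> argument are
hypothetical names; how Varuna performs the split internally is an assumption here.</p>

```python
def split_into_micro_batches(batch, micro_batch_size):
    """Split a batch along its first dimension into micro-batches.

    `batch` is any sequence indexed by example, standing in for a tensor of
    shape (b, d2, d3, ...). Hypothetical helper, not part of Varuna.
    """
    if not micro_batch_size >= 1:
        raise ValueError("micro_batch_size must be at least 1")
    # Every slice keeps whole examples, so each micro-batch still has the
    # layout (examples, d2, d3, ...); the last one may be smaller.
    return [
        batch[start:start + micro_batch_size]
        for start in range(0, len(batch), micro_batch_size)
    ]


if __name__ == "__main__":
    examples = list(range(10))  # a batch of b = 10 examples
    print(split_into_micro_batches(examples, 4))
```

<p>A value whose first dimension is not the batch dimension could not be sliced per example in this way.</p>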
<p>Operations separated by CutPoints should preferably have no shared modules/parameters.
For weight sharing between different parts of the model, you should register separate <code class="xref py py-class docutils literal"><span class="pre">nn.Parameter</span></code>
instances (even for the same tensor) and pass the pair of parameter names as <code class="xref py py-attr docutils literal"><span class="pre">shared_weights</span></code> to the
<code class="xref py py-class docutils literal"><span class="pre">Varuna</span></code> object.</p>
<p>For example, in language models like BERT and GPT2, the weights for word embedding computation at
the beginning of the model are also utilised at the end of the model for prediction logits.
So, if this weight is wrapped in two separate <code class="xref py py-class docutils literal"><span class="pre">torch.nn.Parameter</span></code> instances, they will have two
corresponding “parameter names” (string values) in the model (see <code class="xref py py-func docutils literal"><span class="pre">named_parameters()</span></code> for <code class="xref py py-class docutils literal"><span class="pre">torch.nn.Module</span></code>).
These can be passed as a pair of names for each shared weight to <code class="xref py py-class docutils literal"><span class="pre">Varuna</span></code> as follows:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">varuna</span> <span class="kn">import</span> <span class="n">Varuna</span>

<span class="c1"># list of 2-tuples with parameter names</span>
<span class="n">shared_weights</span> <span class="o">=</span> <span class="p">[(</span><span class="s2">"language_model.embedding.word_embeddings.weight"</span><span class="p">,</span> <span class="s2">"lm_head_weight"</span><span class="p">)]</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">Varuna</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">args</span><span class="o">.</span><span class="n">stage_to_rank_map</span><span class="p">,</span> <span class="n">dry_run_input</span><span class="p">,</span> <span class="n">global_batch_size</span><span class="p">,</span>
               <span class="n">args</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">,</span> <span class="n">args</span><span class="o">.</span><span class="n">fp16</span><span class="p">,</span>
               <span class="n">local_rank</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">local_rank</span><span class="p">,</span>
               <span class="n">device</span><span class="o">=</span><span class="n">args</span><span class="o">.</span><span class="n">local_rank</span><span class="p">,</span>
               <span class="n">shared_weights</span><span class="o">=</span><span class="n">shared_weights</span><span class="p">)</span>  <span class="c1"># passed to varuna init</span>
</pre></div>
</div>
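<p>Because the pairs are plain strings, a typo in a parameter name is easy to miss. The following is a minimal
sketch of a sanity check that could be run before constructing the
<code class="xref py py-class docutils literal"><span class="pre">Varuna</span></code> object. The helper
<code class="docutils literal"><span class="pre">check_shared_weights</span></code> is hypothetical and not part of Varuna; it only assumes that the
names come from the model’s <code class="docutils literal"><span class="pre">named_parameters()</span></code>.</p>

```python
def check_shared_weights(shared_weights, parameter_names):
    """Raise ValueError unless every entry is a pair of known, distinct names.

    Hypothetical helper, not part of Varuna. `parameter_names` is expected to
    hold the names yielded by the model's named_parameters().
    """
    known = set(parameter_names)
    for pair in shared_weights:
        if len(pair) != 2:
            raise ValueError(f"expected a 2-tuple of names, got {pair!r}")
        first, second = pair
        if first == second:
            raise ValueError(f"a shared weight needs two separate names: {first!r}")
        missing = [name for name in pair if name not in known]
        if missing:
            raise ValueError(f"unknown parameter name(s): {missing}")


if __name__ == "__main__":
    names = [
        "language_model.embedding.word_embeddings.weight",
        "lm_head_weight",
    ]
    check_shared_weights([(names[0], names[1])], names)
    print("shared_weights looks consistent")
```

<p>In a real script the list of names could be built with
<code class="docutils literal"><span class="pre">[name</span> <span class="pre">for</span> <span class="pre">name,</span> <span class="pre">_</span> <span class="pre">in</span> <span class="pre">model.named_parameters()]</span></code>.</p>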
</div>
</div>
</div>
</div>
<div class="sphinxsidebar" role="navigation" aria-label="main navigation">
<div class="sphinxsidebarwrapper"><div class="relations">
<h3>Related Topics</h3>
<ul>
<li><a href="index.html">Documentation overview</a><ul>
<li>Previous: <a href="launching.html" title="previous chapter">Launching Varuna</a></li>
<li>Next: <a href="varuna.html" title="next chapter">The Varuna class</a></li>
</ul></li>
</ul>
</div>
<div role="note" aria-label="source link">
<h3>This Page</h3>
<ul class="this-page-menu">
<li><a href="_sources/cutpoint.rst.txt"
rel="nofollow">Show Source</a></li>
</ul>
</div>
<div id="searchbox" style="display: none" role="search">
<h3>Quick search</h3>
<form class="search" action="search.html" method="get">
<div><input type="text" name="q" /></div>
<div><input type="submit" value="Go" /></div>
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
<script type="text/javascript">$('#searchbox').show(0);</script>
</div>
</div>
<div class="clearer"></div>
</div>
<div class="footer">
©2021, Nitika Saran.
|
Powered by <a href="http://sphinx-doc.org/">Sphinx 1.6.7</a>
&amp; <a href="https://github.com/bitprophet/alabaster">Alabaster 0.7.8</a>
|
<a href="_sources/cutpoint.rst.txt"
rel="nofollow">Page source</a>
</div>
</body>
</html>