<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="shortcut icon" href="img/favicon.ico">
<title>Pruning Filters and Channels - Neural Network Distiller</title>
<link href='https://fonts.googleapis.com/css?family=Lato:400,700|Roboto+Slab:400,700|Inconsolata:400,700' rel='stylesheet' type='text/css'>
<link rel="stylesheet" href="css/theme.css" type="text/css" />
<link rel="stylesheet" href="css/theme_extra.css" type="text/css" />
<link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
<link href="extra.css" rel="stylesheet">
<script>
// Current page data
var mkdocs_page_name = "Pruning Filters and Channels";
var mkdocs_page_input_path = "tutorial-struct_pruning.md";
var mkdocs_page_url = null;
</script>
<script src="js/jquery-2.1.1.min.js" defer></script>
<script src="js/modernizr-2.8.3.min.js" defer></script>
<script src="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/highlight.min.js"></script>
<script>hljs.initHighlightingOnLoad();</script>
</head>
<body class="wy-body-for-nav" role="document">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side stickynav">
<div class="wy-side-nav-search">
<a href="index.html" class="icon icon-home"> Neural Network Distiller</a>
<div role="search">
<form id ="rtd-search-form" class="wy-form" action="./search.html" method="get">
<input type="text" name="q" placeholder="Search docs" title="Type search term here" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<ul class="current">
<li class="toctree-l1">
<a class="" href="index.html">Home</a>
</li>
<li class="toctree-l1">
<a class="" href="install.html">Installation</a>
</li>
<li class="toctree-l1">
<a class="" href="usage.html">Usage</a>
</li>
<li class="toctree-l1">
<a class="" href="schedule.html">Compression Scheduling</a>
</li>
<li class="toctree-l1">
<span class="caption-text">Compressing Models</span>
<ul class="subnav">
<li class="">
<a class="" href="pruning.html">Pruning</a>
</li>
<li class="">
<a class="" href="regularization.html">Regularization</a>
</li>
<li class="">
<a class="" href="quantization.html">Quantization</a>
</li>
<li class="">
<a class="" href="knowledge_distillation.html">Knowledge Distillation</a>
</li>
<li class="">
<a class="" href="conditional_computation.html">Conditional Computation</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<span class="caption-text">Algorithms</span>
<ul class="subnav">
<li class="">
<a class="" href="algo_pruning.html">Pruning</a>
</li>
<li class="">
<a class="" href="algo_quantization.html">Quantization</a>
</li>
<li class="">
<a class="" href="algo_earlyexit.html">Early Exit</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<a class="" href="model_zoo.html">Model Zoo</a>
</li>
<li class="toctree-l1">
<a class="" href="jupyter.html">Jupyter Notebooks</a>
</li>
<li class="toctree-l1">
<a class="" href="design.html">Design</a>
</li>
<li class="toctree-l1">
<span class="caption-text">Tutorials</span>
<ul class="subnav">
<li class=" current">
<a class="current" href="tutorial-struct_pruning.html">Pruning Filters and Channels</a>
<ul class="subnav">
<li class="toctree-l3"><a href="#pruning-filters-channels">Pruning Filters & Channels</a></li>
<ul>
<li><a class="toctree-l4" href="#introduction">Introduction</a></li>
<li><a class="toctree-l4" href="#filter-pruning">Filter Pruning</a></li>
<li><a class="toctree-l4" href="#channel-pruning">Channel Pruning</a></li>
</ul>
</ul>
</li>
<li class="">
<a class="" href="tutorial-lang_model.html">Pruning a Language Model</a>
</li>
</ul>
</li>
</ul>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" role="navigation" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="index.html">Neural Network Distiller</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="index.html">Docs</a> »</li>
<li>Tutorials »</li>
<li>Pruning Filters and Channels</li>
<li class="wy-breadcrumbs-aside">
</li>
</ul>
<hr/>
</div>
<div role="main">
<div class="section">
<h1 id="pruning-filters-channels">Pruning Filters & Channels</h1>
<h2 id="introduction">Introduction</h2>
<p>Channel and filter pruning are examples of structured pruning, which creates compressed models that do not require special hardware to execute. This makes these forms of structured pruning particularly interesting and popular.
In networks that have serial data dependencies, it is fairly straightforward to understand and define how to prune channels and filters. However, in more complex models with parallel data dependencies (paths), such as ResNets (skip connections) and GoogLeNet (Inception layers), things become increasingly complex and require a deeper understanding of the data flow in the model in order to define the pruning schedule.<br />
This post explains channel and filter pruning, the challenges, and how to define a Distiller pruning schedule for these structures. The details of the implementation are left for a separate post.</p>
<p>Before we dive into pruning, let’s level-set on the terminology, because different people (and even research papers) do not always agree on the nomenclature. This reflects my understanding of the nomenclature, and therefore these are the names used in Distiller. I’ll restrict this discussion to Convolution layers in CNNs, to contain the scope of the topic I’ll be covering, although Distiller supports pruning of other structures such as matrix columns and rows.
PyTorch describes <a href="https://pytorch.org/docs/stable/nn.html#conv2d"><code>torch.nn.Conv2d</code></a> as applying “a 2D convolution over an input signal composed of several input planes.” We call each of these input planes a <strong>feature-map</strong> (or FM, for short). Another name is <strong>input channel</strong>, as in the R/G/B channels of an image. Some people refer to feature-maps as <strong>activations</strong> (i.e. the activation of neurons), although I think strictly speaking <strong>activations</strong> are the output of an activation layer that was fed a group of feature-maps. Because it is very common, and because the use of an activation is orthogonal to our discussion, I will use <strong>activations</strong> to refer to the output of a Convolution layer (i.e. 3D stack of feature-maps).</p>
<p>In the PyTorch documentation, Convolution outputs have shape (N, C<sub>out</sub>, H<sub>out</sub>, W<sub>out</sub>), where N is the batch size, C<sub>out</sub> is the number of output channels, H<sub>out</sub> is the height of the output planes in pixels, and W<sub>out</sub> is their width in pixels. We won’t be paying much attention to the batch size since it’s not important to our discussion, so without loss of generality we can set N=1. I’m also assuming the most common case of Convolutions with <code>groups==1</code>.
Convolution weights are 4D: (F, C, K, K) where F is the number of filters, C is the number of channels, and K is the kernel size (we can assume the kernel height and width are equal for simplicity). A <strong>kernel</strong> is a 2D matrix (K, K) that is part of a 3D feature detector. This feature detector is called a <strong>filter</strong> and it is basically a stack of 2D <strong>kernels</strong>. Each kernel is convolved with a 2D input channel (i.e. feature-map) so if there are C<sub>in</sub> channels in the input, then there are C<sub>in</sub> kernels in a filter (C == C<sub>in</sub>). Each filter is convolved with the entire input to create a single output channel (i.e. feature-map). If there are C<sub>out</sub> output channels, then there are C<sub>out</sub> filters (F == C<sub>out</sub>).</p>
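<p>To make the shapes concrete, here is a minimal sketch in plain Python. The helper names <code>conv2d_weight_shape</code> and <code>num_weights</code> and the layer sizes are made up for illustration; they are not part of PyTorch or Distiller.</p>
<pre><code class="python">def conv2d_weight_shape(in_channels, out_channels, kernel_size):
    """Return the (F, C, K, K) weight shape of a groups==1 2D convolution."""
    return (out_channels, in_channels, kernel_size, kernel_size)


def num_weights(shape):
    """Count the scalar weights in a tensor of the given shape."""
    total = 1
    for dim in shape:
        total *= dim
    return total


# Hypothetical layer: 3 input channels (R/G/B), 16 filters, 3x3 kernels.
shape = conv2d_weight_shape(in_channels=3, out_channels=16, kernel_size=3)
F, C, K, _ = shape
print(shape)               # (16, 3, 3, 3)
print(C * K * K)           # weights in one filter: 27
print(num_weights(shape))  # weights in the whole layer: 432
</code></pre>
<p>Removing one filter deletes C * K * K weights from this layer, and it also removes one output feature-map.</p>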
<h2 id="filter-pruning">Filter Pruning</h2>
<p>Filter pruning and channel pruning are very similar, and I’ll expand on that similarity later on – but for now let’s focus on filter pruning.<br />
In filter pruning we use some criterion to determine which filters are <strong>important</strong> and which are not. Researchers came up with all sorts of pruning criteria: the L1-magnitude of the filters (citation), the entropy of the activations (citation), and the classification accuracy reduction (citation) are just some examples. Disregarding how we chose the filters to prune, let’s imagine that in the diagram below, we chose to prune (remove) the green and orange filters (the circle with the “*” designates a Convolution operation).</p>
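<p>As an illustration of one such criterion, the sketch below ranks filters by the L1-magnitude of their weights using plain Python lists. The function names and the toy weights are assumptions made for this example; this is not Distiller’s implementation.</p>
<pre><code class="python">def filter_l1_norms(weights):
    """weights is a nested list shaped (F, C, K, K); return one L1 norm per filter."""
    return [
        sum(abs(w) for kernel in filt for row in kernel for w in row)
        for filt in weights
    ]


def filters_to_prune(weights, fraction):
    """Return the sorted indices of the lowest-L1 filters, covering `fraction` of them."""
    norms = filter_l1_norms(weights)
    num_pruned = int(fraction * len(norms))
    ranked = sorted(range(len(norms)), key=lambda i: norms[i])
    return sorted(ranked[:num_pruned])


# Toy weights: 4 filters, 1 channel, 2x2 kernels.
weights = [
    [[[0.5, -0.5], [0.5, -0.5]]],  # L1 = 2.0
    [[[0.1, 0.0], [0.0, -0.1]]],   # L1 = 0.2
    [[[1.0, 1.0], [-1.0, 1.0]]],   # L1 = 4.0
    [[[0.0, 0.2], [0.1, 0.0]]],    # L1 is about 0.3
]
print(filters_to_prune(weights, fraction=0.5))  # [1, 3]
</code></pre>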
<p>Since we have two fewer filters operating on the input, we must have two fewer output feature-maps. So when we prune filters, besides changing the physical size of the weight tensors, we also need to reconfigure the Convolution layer being pruned (change its <code>out_channels</code>) and the following Convolution layer (change its <code>in_channels</code>). And finally, because the next layer’s input is now smaller (has fewer channels), we should also shrink the next layer’s weight tensors, by removing the channels corresponding to the filters we pruned. We say that there is a <strong>data-dependency</strong> between the two Convolution layers. I didn’t make any mention of the activation function that usually follows Convolution, because these functions are parameter-less and are not sensitive to the shape of their input.
There are some other dependencies that Distiller resolves (such as Optimizer parameters tightly-coupled to the weights) that I won’t discuss here, because they are implementation details.
<center><img alt="Example 1" src="imgs/pruning_structs_ex1.png" /></center></p>
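<p>The data-dependency described above can be sketched with plain Python lists. Kernels are replaced by string labels so it is easy to see what survives; the helper names and layer sizes are hypothetical and only meant to show which dimension of each weight tensor shrinks.</p>
<pre><code class="python">def remove_filters(weights, pruned):
    """Drop whole filters (dimension 0 of an (F, C, ...) weight tensor)."""
    return [filt for f, filt in enumerate(weights) if f not in pruned]


def remove_channels(weights, pruned):
    """Drop the same kernel index from every filter (dimension 1)."""
    return [
        [kernel for c, kernel in enumerate(filt) if c not in pruned]
        for filt in weights
    ]


def shape2(weights):
    """Return (number of filters, number of channels)."""
    return (len(weights), len(weights[0]))


# Label "kFC" stands in for the 2D kernel of filter F, channel C.
conv1 = [["k%d%d" % (f, c) for c in range(3)] for f in range(4)]  # F=4, C=3
conv2 = [["k%d%d" % (f, c) for c in range(4)] for f in range(2)]  # F=2, C=4

pruned = {1, 3}                              # filters chosen in conv1
conv1_thin = remove_filters(conv1, pruned)   # out_channels: 4 becomes 2
conv2_thin = remove_channels(conv2, pruned)  # in_channels: 4 becomes 2
print(shape2(conv1_thin))  # (2, 3)
print(shape2(conv2_thin))  # (2, 2)
print(conv2_thin[0])       # ['k00', 'k02']
</code></pre>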
<p>The scheduler YAML syntax for this example is pasted below. We use L1-norm ranking of weight filters, and the pruning-rate is set by the AGP algorithm (Automatic Gradual Pruning). The Convolution layers are conveniently named <code>conv1</code> and <code>conv2</code> in this example.</p>
<pre><code>pruners:
  example_pruner:
    class: L1RankedStructureParameterPruner_AGP
    initial_sparsity: 0.10
    final_sparsity: 0.50
    group_type: Filters
    weights: [module.conv1.weight]
</code></pre>
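<p>To give a feel for how a gradual schedule moves from <code>initial_sparsity</code> to <code>final_sparsity</code>, here is a minimal sketch of the cubic ramp commonly associated with AGP. The function name, the <code>start_step</code> and <code>num_steps</code> parameters, and the 10-step span are assumptions for illustration; Distiller’s own implementation and its scheduling fields may differ.</p>
<pre><code class="python">def agp_target_sparsity(step, initial_sparsity, final_sparsity, start_step, num_steps):
    """Cubic ramp: starts at initial_sparsity and flattens out at final_sparsity."""
    # Clamp progress to the range [0, 1].
    progress = min(max(step - start_step, 0), num_steps) / num_steps
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** 3


# Same sparsity endpoints as the YAML above, spread over 10 hypothetical steps.
for step in (0, 5, 10):
    print(step, round(agp_target_sparsity(step, 0.10, 0.50, 0, 10), 4))
# 0 0.1
# 5 0.45
# 10 0.5
</code></pre>
<p>Most of the pruning happens early, when the network has the most redundancy, and the rate tapers off as the target is approached.</p>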
<p>Now let’s add a Batch Normalization layer between the two convolutions:
<center><img alt="Example 2" src="imgs/pruning_structs_ex2.png" /></center></p>
<p>The Batch Normalization layer is parameterized by a couple of tensors that contain information per input-channel (i.e. scale and shift). Because our Convolution produces fewer output FMs, and these are the input to the Batch Normalization layer, we also need to reconfigure the Batch Normalization layer. And we also need to physically shrink the Batch Normalization layer’s scale and shift tensors, which are coefficients in the BN input transformation. Moreover, the scale and shift coefficients that we remove from the tensors must correspond to the filters (or output feature-map channels) that we removed from the Convolution weight tensors. This small nuance will prove to be a large pain, but we’ll get to that in later examples.
The presence of a Batch Normalization layer in the example above is transparent to us, and in fact, the YAML schedule does not change. Distiller detects the presence of Batch Normalization layers and adjusts their parameters automatically.</p>
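<p>The index correspondence is the important part, and a small sketch in plain Python shows it. The names and values are invented for this example and do not reflect how Distiller stores these tensors.</p>
<pre><code class="python">def keep_indices(num_filters, pruned):
    """Indices of the filters that survive pruning, in their original order."""
    return [i for i in range(num_filters) if i not in pruned]


def select(values, kept):
    """Keep only the entries at the surviving indices."""
    return [values[i] for i in kept]


conv_filters = ["f0", "f1", "f2", "f3"]  # stand-ins for 3D filters
bn_scale = [1.0, 0.9, 1.1, 0.8]          # one scale per output channel
bn_shift = [0.0, 0.1, -0.1, 0.2]         # one shift per output channel

# The same index list must be applied to all three tensors.
kept = keep_indices(len(conv_filters), pruned={1, 3})
print(select(conv_filters, kept))  # ['f0', 'f2']
print(select(bn_scale, kept))      # [1.0, 1.1]
print(select(bn_shift, kept))      # [0.0, -0.1]
</code></pre>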
<p>Let’s look at another example, with non-serial data-dependencies. Here, the output of <code>conv1</code> is the input for <code>conv2</code> and <code>conv3</code>. This is an example of parallel data-dependency, since both <code>conv2</code> and <code>conv3</code> depend on <code>conv1</code>.
<center><img alt="Example 3" src="imgs/pruning_structs_ex3.png" /></center></p>
<p>Note that the Distiller YAML schedule is unchanged from the previous two examples, since we are still only explicitly pruning the weight filters of <code>conv1</code>. The weight channels of <code>conv2</code> and <code>conv3</code> are pruned implicitly by Distiller in a process called “Thinning” (on which I will expand in a different post).</p>
<p>Next, let’s look at another example also involving three Convolutions, but this time we want to prune the filters of two convolutional layers, whose outputs are summed element-wise and fed into a third Convolution.
In this example <code>conv3</code> is dependent on both <code>conv1</code> and <code>conv2</code>, and there are two implications to this dependency. The first, and more obvious implication, is that we need to prune the same number of filters from both <code>conv1</code> and <code>conv2</code>. Since we apply element-wise addition on the outputs of <code>conv1</code> and <code>conv2</code>, they must have the same shape - and they can only have the same shape if <code>conv1</code> and <code>conv2</code> prune the same number of filters. The second implication of this triangular data-dependency is that both <code>conv1</code> and <code>conv2</code> must prune the <strong>same</strong> filters! Let’s imagine for a moment, that we ignore this second constraint. The diagram below illustrates the dilemma that arises: how should we prune the channels of the weights of <code>conv3</code>? Obviously, we can’t: a channel of the sum can be removed only if both layers pruned the corresponding filter, so when the choices differ there is no consistent set of <code>conv3</code> channels to remove.
<center><img alt="Example 4" src="imgs/pruning_structs_ex4.png" /></center></p>
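<p>One way to see the dilemma is to ask which channels of the element-wise sum could be dropped from the input of <code>conv3</code>. The sketch below assumes that a summed channel is removable only when every branch feeding the addition pruned that same index; the function name and the filter choices are hypothetical.</p>
<pre><code class="python">def removable_channels(num_channels, pruned_per_branch):
    """A channel of the element-wise sum is removable only if every branch pruned it."""
    return [
        c for c in range(num_channels)
        if all(c in pruned for pruned in pruned_per_branch)
    ]


# conv1 and conv2 each have 4 filters and each prunes 2 of them.
different = [{0, 1}, {2, 3}]  # the two layers chose different filters
same = [{0, 1}, {0, 1}]       # the two layers chose the same filters
print(removable_channels(4, different))  # []
print(removable_channels(4, same))       # [0, 1]
</code></pre>
<p>With different choices, every summed channel still carries data from one of the two layers, so nothing can be removed from <code>conv3</code>.</p>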
<p>We must apply the second constraint, and that means that we now need to be proactive: we need to decide whether to prune <code>conv1</code> and <code>conv2</code> according to the filter-pruning choices of <code>conv1</code> or of <code>conv2</code>. The diagram below illustrates the pruning scheme after deciding to follow the pruning choices of <code>conv1</code>.
<center><img alt="Example 5" src="imgs/pruning_structs_ex5.png" /></center></p>
<p>The YAML compression schedule syntax needs to be able to express the two dependencies (or constraints) discussed above. First we need to tell the Filter Pruner that there is a dependency of type <strong>Leader</strong>. This means that all of the tensors listed in the <code>weights</code> field are pruned together, to the same extent at each iteration, and that to prune the filters we will use the pruning decisions of the first tensor listed. In the example below <code>module.conv1.weight</code> and <code>module.conv2.weight</code> are pruned together according to the pruning choices for <code>module.conv1.weight</code>.</p>
<pre><code>pruners:
  example_pruner:
    class: L1RankedStructureParameterPruner_AGP
    initial_sparsity: 0.10
    final_sparsity: 0.50
    group_type: Filters
    group_dependency: Leader
    weights: [module.conv1.weight, module.conv2.weight]
</code></pre>
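<p>The Leader semantics can be sketched as follows. This is an illustration of the idea only, under the assumption that each filter is flattened into a list of weights; <code>leader_prune</code> is a made-up name and the code is not taken from Distiller.</p>
<pre><code class="python">def l1_norm(filt):
    """L1 norm of one filter, given as a flat list of weights."""
    return sum(abs(w) for w in filt)


def leader_prune(tensors, fraction):
    """tensors maps a name to a list of filters. The first entry is the leader:
    its L1 ranking picks the filter indices removed from every tensor."""
    names = list(tensors)
    leader = tensors[names[0]]
    num_pruned = int(fraction * len(leader))
    ranked = sorted(range(len(leader)), key=lambda i: l1_norm(leader[i]))
    pruned = set(ranked[:num_pruned])
    thinned = {
        name: [f for i, f in enumerate(filters) if i not in pruned]
        for name, filters in tensors.items()
    }
    return sorted(pruned), thinned


tensors = {
    "module.conv1.weight": [[0.9, -0.8], [0.1, 0.0], [0.7, 0.6], [-0.1, 0.1]],
    "module.conv2.weight": [[0.0, 0.1], [0.9, 0.9], [0.1, 0.0], [0.8, -0.7]],
}
pruned, thinned = leader_prune(tensors, fraction=0.5)
print(pruned)                          # [1, 3]
print(thinned["module.conv2.weight"])  # [[0.0, 0.1], [0.1, 0.0]]
</code></pre>
<p>Note that <code>module.conv2.weight</code> loses its two largest filters here, because only the leader’s ranking is consulted.</p>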
<p>When we turn to filter-pruning ResNets we see some pretty long dependency chains because of the skip-connections. If you don’t pay attention, you can easily under-specify (or mis-specify) dependency chains and Distiller will exit with an exception. The exception does not explain the specification error and this needs to be improved.</p>
<h2 id="channel-pruning">Channel Pruning</h2>
<p>Channel pruning is very similar to Filter pruning with all the details of dependencies reversed. Look again at example #1, but this time imagine that we’ve changed our schedule to prune the <strong>channels</strong> of <code>module.conv2.weight</code>.</p>
<pre><code>pruners:
  example_pruner:
    class: L1RankedStructureParameterPruner_AGP
    initial_sparsity: 0.10
    final_sparsity: 0.50
    group_type: Channels
    weights: [module.conv2.weight]
</code></pre>
<p>As the diagram shows, <code>conv1</code> is now dependent on <code>conv2</code> and its weight filters will be implicitly pruned according to the channels removed from the weights of <code>conv2</code>.
<center><img alt="Example 1" src="imgs/pruning_structs_ex1.png" /></center></p>
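<p>The reversed dependency can be sketched in plain Python as well. Here the channels of <code>conv2</code> are ranked by L1 norm and the weakest one determines which filter of <code>conv1</code> is removed; the function name and the toy weights are assumptions made for this example.</p>
<pre><code class="python">def channel_l1_norms(weights):
    """weights is a nested list shaped (F, C, K, K); return one L1 norm per channel."""
    num_channels = len(weights[0])
    return [
        sum(abs(w) for filt in weights for row in filt[c] for w in row)
        for c in range(num_channels)
    ]


# conv2 weights: 2 filters, 3 channels, 1x1 kernels.
conv2 = [
    [[[0.5]], [[0.0]], [[-0.3]]],
    [[[-0.4]], [[0.1]], [[0.6]]],
]
conv1_filters = ["f0", "f1", "f2"]  # conv1 has one filter per conv2 channel

norms = channel_l1_norms(conv2)
weakest = min(range(len(norms)), key=lambda c: norms[c])
conv1_thin = [f for i, f in enumerate(conv1_filters) if i != weakest]
print(weakest)     # 1
print(conv1_thin)  # ['f0', 'f2']
</code></pre>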
<p>Geek On.</p>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="tutorial-lang_model.html" class="btn btn-neutral float-right" title="Pruning a Language Model">Next <span class="icon icon-circle-arrow-right"></span></a>
<a href="design.html" class="btn btn-neutral" title="Design"><span class="icon icon-circle-arrow-left"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<!-- Copyright etc -->
</div>
Built with <a href="http://www.mkdocs.org">MkDocs</a> using a <a href="https://github.com/snide/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<div class="rst-versions" role="note" style="cursor: pointer">
<span class="rst-current-version" data-toggle="rst-current-version">
<span><a href="design.html" style="color: #fcfcfc;">« Previous</a></span>
<span style="margin-left: 15px"><a href="tutorial-lang_model.html" style="color: #fcfcfc">Next »</a></span>
</span>
</div>
<script>var base_url = '.';</script>
<script src="js/theme.js" defer></script>
<script src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS_HTML" defer></script>
<script src="search/main.js" defer></script>
</body>
</html>