<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="shortcut icon" href="img/favicon.ico">
<title>Knowledge Distillation - Neural Network Distiller</title>
<link href='https://fonts.googleapis.com/css?family=Lato:400,700|Roboto+Slab:400,700|Inconsolata:400,700' rel='stylesheet' type='text/css'>
<link rel="stylesheet" href="css/theme.css" type="text/css" />
<link rel="stylesheet" href="css/theme_extra.css" type="text/css" />
<link rel="stylesheet" href="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css">
<link href="extra.css" rel="stylesheet">
<script>
// Current page data
var mkdocs_page_name = "Knowledge Distillation";
var mkdocs_page_input_path = "knowledge_distillation.md";
var mkdocs_page_url = null;
</script>
<script src="js/jquery-2.1.1.min.js" defer></script>
<script src="js/modernizr-2.8.3.min.js" defer></script>
<script src="//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/highlight.min.js"></script>
<script>hljs.initHighlightingOnLoad();</script>
</head>
<body class="wy-body-for-nav" role="document">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side stickynav">
<div class="wy-side-nav-search">
<a href="index.html" class="icon icon-home"> Neural Network Distiller</a>
<div role="search">
<form id ="rtd-search-form" class="wy-form" action="./search.html" method="get">
<input type="text" name="q" placeholder="Search docs" title="Type search term here" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<ul class="current">
<li class="toctree-l1">
<a class="" href="index.html">Home</a>
</li>
<li class="toctree-l1">
<a class="" href="install.html">Installation</a>
</li>
<li class="toctree-l1">
<a class="" href="usage.html">Usage</a>
</li>
<li class="toctree-l1">
<a class="" href="schedule.html">Compression Scheduling</a>
</li>
<li class="toctree-l1">
<span class="caption-text">Compressing Models</span>
<ul class="subnav">
<li class="">
<a class="" href="pruning.html">Pruning</a>
</li>
<li class="">
<a class="" href="regularization.html">Regularization</a>
</li>
<li class="">
<a class="" href="quantization.html">Quantization</a>
</li>
<li class=" current">
<a class="current" href="knowledge_distillation.html">Knowledge Distillation</a>
<ul class="subnav">
<li class="toctree-l3"><a href="#knowledge-distillation">Knowledge Distillation</a></li>
<ul>
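<li><a class="toctree-l4" href="#new-hyper-parameters">New Hyper-Parameters</a></li>
<li><a class="toctree-l4" href="#combining-with-other-model-compression-techniques">Combining with Other Model Compression Techniques</a></li>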
<li><a class="toctree-l4" href="#references">References</a></li>
</ul>
</ul>
</li>
<li class="">
<a class="" href="conditional_computation.html">Conditional Computation</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<span class="caption-text">Algorithms</span>
<ul class="subnav">
<li class="">
<a class="" href="algo_pruning.html">Pruning</a>
</li>
<li class="">
<a class="" href="algo_quantization.html">Quantization</a>
</li>
<li class="">
<a class="" href="algo_earlyexit.html">Early Exit</a>
</li>
</ul>
</li>
<li class="toctree-l1">
<a class="" href="model_zoo.html">Model Zoo</a>
</li>
<li class="toctree-l1">
<a class="" href="jupyter.html">Jupyter Notebooks</a>
</li>
<li class="toctree-l1">
<a class="" href="design.html">Design</a>
</li>
<li class="toctree-l1">
<span class="caption-text">Tutorials</span>
<ul class="subnav">
<li class="">
<a class="" href="tutorial-struct_pruning.html">Pruning Filters and Channels</a>
</li>
<li class="">
<a class="" href="tutorial-lang_model.html">Pruning a Language Model</a>
</li>
</ul>
</li>
</ul>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" role="navigation" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="index.html">Neural Network Distiller</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="main">
<div class="section">
<h1 id="knowledge-distillation">Knowledge Distillation</h1>
<p>(For details on how to train a model with knowledge distillation in Distiller, see the <a href="schedule.html#knowledge-distillation">Compression Scheduling</a> page.)</p>
<p>Knowledge distillation is a model compression method in which a small model is trained to mimic a pre-trained, larger model (or ensemble of models). This training setting is sometimes referred to as "teacher-student", where the large model is the teacher and the small model is the student (we use these pairs of terms interchangeably).</p>
<p>The method was first proposed by <a href="#bucila-et-al-2006">Bucila et al., 2006</a> and generalized by <a href="#hinton-et-al-2015">Hinton et al., 2015</a>. The implementation in Distiller is based on the latter publication. Here we provide a summary of the method; for more details, refer to the Hinton et al. paper (a <a href="https://www.youtube.com/watch?v=EK61htlw8hY">video lecture</a> with <a href="http://www.ttic.edu/dl/dark14.pdf">slides</a> is also available).</p>
<p>In distillation, knowledge is transferred from the teacher model to the student by minimizing a loss function in which the target is the distribution of class probabilities predicted by the teacher model, that is, the output of a softmax function applied to the teacher model's logits. However, in many cases this probability distribution assigns a very high probability to the correct class, with all other class probabilities very close to 0. As such, it doesn't provide much information beyond the ground-truth labels already provided in the dataset. To tackle this issue, <a href="#hinton-et-al-2015">Hinton et al., 2015</a> introduced the concept of "softmax temperature". The probability <script type="math/tex">p_i</script> of class <script type="math/tex">i</script> is calculated from the logits <script type="math/tex">z</script> as:</p>
<p>
<script type="math/tex; mode=display">p_i = \frac{\exp\left(\frac{z_i}{T}\right)}{\sum_{j} \exp\left(\frac{z_j}{T}\right)}</script>
</p>
<p>where <script type="math/tex">T</script> is the temperature parameter. When <script type="math/tex">T=1</script> we get the standard softmax function. As <script type="math/tex">T</script> grows, the probability distribution generated by the softmax function becomes softer, providing more information about which classes the teacher found more similar to the predicted class. Hinton calls this the "dark knowledge" embedded in the teacher model, and it is this dark knowledge that we transfer to the student model in the distillation process. When computing the loss function against the teacher's soft targets, we use the same value of <script type="math/tex">T</script> to compute the softmax on the student's logits. We call this loss the "distillation loss".</p>
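<p>The effect of the temperature is easy to see numerically. The following is a minimal plain-Python sketch of the softmax formula above, not Distiller's implementation; the function name <code>softmax_with_temperature</code> and the logit values are hypothetical, chosen only to give a sharply peaked distribution at <script type="math/tex">T=1</script>.</p>

```python
import math


def softmax_with_temperature(logits, T=1.0):
    """Return p_i = exp(z_i / T) / sum_j exp(z_j / T) for the given logits."""
    scaled = [z / T for z in logits]
    # Subtracting the maximum before exponentiating avoids overflow and
    # leaves the result unchanged (it cancels between numerator and denominator).
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical teacher logits for a 3-class problem.
teacher_logits = [9.0, 4.0, 1.0]

for T in (1.0, 4.0, 20.0):
    probs = softmax_with_temperature(teacher_logits, T)
    print(f"T={T:>5}: " + ", ".join(f"{p:.3f}" for p in probs))
```

<p>At <script type="math/tex">T=1</script> almost all of the probability mass falls on the first class; as <script type="math/tex">T</script> grows, the other two classes receive a visibly larger share, while the ranking of the classes is preserved.</p>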
<p><a href="#hinton-et-al-2015">Hinton et al., 2015</a> found that it is also beneficial to train the distilled model to produce the correct labels (based on the ground truth) in addition to the teacher's soft labels. Hence, we also calculate the "standard" loss between the student's predicted class probabilities and the ground-truth labels (also called "hard labels/targets"). We dub this loss the "student loss". When calculating the class probabilities for the student loss we use <script type="math/tex">T = 1</script>.</p>
<p>The overall loss function, incorporating both distillation and student losses, is calculated as:</p>
<p>
<script type="math/tex; mode=display">\mathcal{L}(x;W) = \alpha \cdot \mathcal{H}(y, \sigma(z_s; T=1)) + \beta \cdot \mathcal{H}(\sigma(z_t; T=\tau), \sigma(z_s; T=\tau))</script>
</p>
<p>where <script type="math/tex">x</script> is the input, <script type="math/tex">W</script> are the student model parameters, <script type="math/tex">y</script> is the ground truth label, <script type="math/tex">\mathcal{H}</script> is the cross-entropy loss function, <script type="math/tex">\sigma</script> is the softmax function parameterized by the temperature <script type="math/tex">T</script>, and <script type="math/tex">\alpha</script> and <script type="math/tex">\beta</script> are coefficients. <script type="math/tex">z_s</script> and <script type="math/tex">z_t</script> are the logits of the student and teacher respectively.</p>
<p><img alt="Knowledge Distillation" src="imgs/knowledge_distillation.png" /></p>
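<p>To make the overall loss concrete, here is a minimal per-example sketch in plain Python. It is an illustration of the formula, not Distiller's implementation; the function names and the example values of <script type="math/tex">\alpha</script>, <script type="math/tex">\beta</script> and <script type="math/tex">\tau</script> are hypothetical. The ground-truth label <script type="math/tex">y</script> is given as a class index and expanded to a one-hot distribution.</p>

```python
import math


def softmax_with_temperature(logits, T=1.0):
    """Softmax of `logits` at temperature T (T=1 gives the standard softmax)."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(target, predicted):
    """H(p, q) = -sum_i p_i * log(q_i); terms with p_i == 0 contribute nothing."""
    return -sum(p * math.log(q) for p, q in zip(target, predicted) if p > 0)


def distillation_objective(student_logits, teacher_logits, label, alpha, beta, tau):
    """alpha * H(y, sigma(z_s; T=1)) + beta * H(sigma(z_t; T=tau), sigma(z_s; T=tau))."""
    # Student loss: hard (one-hot) ground-truth target, softmax at T = 1.
    hard_target = [1.0 if i == label else 0.0 for i in range(len(student_logits))]
    student_loss = cross_entropy(hard_target, softmax_with_temperature(student_logits, 1.0))

    # Distillation loss: teacher's soft targets, same temperature tau on both sides.
    soft_target = softmax_with_temperature(teacher_logits, tau)
    distillation_loss = cross_entropy(soft_target, softmax_with_temperature(student_logits, tau))

    return alpha * student_loss + beta * distillation_loss


# Hypothetical logits for one 3-class example whose ground-truth class is 0.
student_logits = [5.0, 3.0, 0.5]
teacher_logits = [9.0, 4.0, 1.0]
loss = distillation_objective(student_logits, teacher_logits, label=0,
                              alpha=0.1, beta=0.9, tau=4.0)
print(f"overall loss: {loss:.4f}")
```

<p>Note that <a href="#hinton-et-al-2015">Hinton et al., 2015</a> point out that the gradients produced by the soft targets scale as <script type="math/tex">1/T^2</script>, and recommend multiplying that term by <script type="math/tex">T^2</script> when using both hard and soft targets; the sketch follows the formula above as written and does not include this factor.</p>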
<h2 id="new-hyper-parameters">New Hyper-Parameters</h2>
<p>In general, <script type="math/tex">\tau</script>, <script type="math/tex">\alpha</script> and <script type="math/tex">\beta</script> are hyper-parameters.</p>
<p>In their experiments, <a href="#hinton-et-al-2015">Hinton et al., 2015</a> use temperature values ranging from 1 to 20. They note that, empirically, lower temperatures work better when the student model is very small compared to the teacher model. This makes sense if we consider that as we raise the temperature, the resulting soft-label distribution becomes richer in information, and a very small model might not be able to capture all of it. However, there's no clear way to predict in advance how much of this information the student model will be able to capture.</p>
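<p>One way to quantify "richer in information" is the entropy of the soft-label distribution, which grows with <script type="math/tex">T</script> toward its maximum of <script type="math/tex">\log_2 K</script> bits for <script type="math/tex">K</script> classes. The plain-Python sketch below prints this entropy for a few temperatures in the 1 to 20 range, using hypothetical teacher logits; it is only an illustration and does not tell you which temperature suits a given student.</p>

```python
import math


def softmax_with_temperature(logits, T):
    """Softmax of `logits` at temperature T."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy_bits(probs):
    """Shannon entropy in bits; zero-probability terms contribute nothing."""
    return -sum(p * math.log2(p) for p in probs if p > 0)


# Hypothetical teacher logits for a 4-class problem (maximum entropy: 2 bits).
teacher_logits = [9.0, 4.0, 1.0, 0.5]

for T in (1, 2, 5, 10, 20):
    h = entropy_bits(softmax_with_temperature(teacher_logits, T))
    print(f"T={T:>2}: entropy = {h:.3f} bits")
```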
<p>With regard to <script type="math/tex">\alpha</script> and <script type="math/tex">\beta</script>, <a href="#hinton-et-al-2015">Hinton et al., 2015</a> use a weighted average of the distillation loss and the student loss, that is, <script type="math/tex">\beta = 1 - \alpha</script>. They note that, in general, they obtained the best results when setting <script type="math/tex">\alpha</script> to be much smaller than <script type="math/tex">\beta</script> (although in one of their experiments they use <script type="math/tex">\alpha = \beta = 0.5</script>). Other works that use knowledge distillation don't use a weighted average: some set <script type="math/tex">\alpha = 1</script> and leave <script type="math/tex">\beta</script> tunable, while others don't constrain either coefficient.</p>
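<p>These conventions differ only in how <script type="math/tex">\alpha</script> and <script type="math/tex">\beta</script> are constrained. A minimal sketch follows; the <code>LossWeights</code> class and its method names are hypothetical and not part of Distiller, and the value <script type="math/tex">\alpha = 0.1</script> is arbitrary, picked only so that <script type="math/tex">\alpha</script> is much smaller than <script type="math/tex">\beta</script>.</p>

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LossWeights:
    """Hypothetical container for the coefficients of the overall KD loss."""
    alpha: float  # weight of the student loss (hard labels, T = 1)
    beta: float   # weight of the distillation loss (soft labels, T = tau)

    @classmethod
    def weighted_average(cls, alpha):
        """Hinton et al., 2015 convention: beta = 1 - alpha."""
        if not 0.0 <= alpha <= 1.0:
            raise ValueError("alpha must lie in [0, 1] for a weighted average")
        return cls(alpha=alpha, beta=1.0 - alpha)

    @classmethod
    def unit_student_weight(cls, beta):
        """Alternative convention: alpha fixed to 1, only beta is tuned."""
        return cls(alpha=1.0, beta=beta)

    def combine(self, student_loss, distillation_loss):
        """Return alpha * student_loss + beta * distillation_loss."""
        return self.alpha * student_loss + self.beta * distillation_loss


# Weighted average with alpha much smaller than beta.
w = LossWeights.weighted_average(0.1)
print(w, w.combine(student_loss=2.0, distillation_loss=1.0))

# Unconstrained: both coefficients are set independently.
print(LossWeights(alpha=0.5, beta=2.0))
```

<p>Keeping the constraint in the constructor means a tuning script only has to search over one value for the weighted-average and <script type="math/tex">\alpha = 1</script> conventions, and over two values otherwise.</p>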
<h2 id="combining-with-other-model-compression-techniques"><a name="combining"></a>Combining with Other Model Compression Techniques</h2>
<p>In the "basic" scenario, the smaller (student) model is a pre-defined architecture that simply has fewer parameters than the teacher model. For example, we could train ResNet-18 by distilling knowledge from ResNet-34. But a model with smaller capacity can also be obtained through other model compression techniques: sparsification and/or quantization. For example, we could train a 4-bit ResNet-18 model using some quantization-aware training method, and add a distillation loss function as described above. In that case, the teacher model can even be an FP32 ResNet-18 model. The same goes for pruning and regularization.</p>
<p><a href="#tann-et-al-2017">Tann et al., 2017</a>, <a href="#mishra-and-marr-2018">Mishra and Marr, 2018</a> and <a href="#polino-et-al-2018">Polino et al., 2018</a> are some works that combine knowledge distillation with <strong>quantization</strong>. <a href="#theis-et-al-2018">Theis et al., 2018</a> and <a href="#ashok-et-al-2018">Ashok et al., 2018</a> combine distillation with <strong>pruning</strong>.</p>
<h2 id="references">References</h2>
<div id="bucila-et-al-2006"></div>
<p><strong>Cristian Bucila, Rich Caruana and Alexandru Niculescu-Mizil</strong>. Model Compression. <a href="https://www.cs.cornell.edu/~caruana/compression.kdd06.pdf">KDD, 2006</a></p>
<div id="hinton-et-al-2015"></div>
<p><strong>Geoffrey Hinton, Oriol Vinyals and Jeff Dean</strong>. Distilling the Knowledge in a Neural Network. <a href="https://arxiv.org/abs/1503.02531">arxiv:1503.02531</a></p>
<div id="tann-et-al-2017"></div>
<p><strong>Hokchhay Tann, Soheil Hashemi, Iris Bahar and Sherief Reda</strong>. Hardware-Software Codesign of Accurate, Multiplier-free Deep Neural Networks. <a href="https://arxiv.org/abs/1705.04288">DAC, 2017</a></p>
<div id="mishra-and-marr-2018"></div>
<p><strong>Asit Mishra and Debbie Marr</strong>. Apprentice: Using Knowledge Distillation Techniques To Improve Low-Precision Network Accuracy. <a href="https://openreview.net/forum?id=B1ae1lZRb">ICLR, 2018</a></p>
<div id="polino-et-al-2018"></div>
<p><strong>Antonio Polino, Razvan Pascanu and Dan Alistarh</strong>. Model compression via distillation and quantization. <a href="https://openreview.net/forum?id=S1XolQbRW">ICLR, 2018</a></p>
<div id="ashok-et-al-2018"></div>
<p><strong>Anubhav Ashok, Nicholas Rhinehart, Fares Beainy and Kris M. Kitani</strong>. N2N learning: Network to Network Compression via Policy Gradient Reinforcement Learning. <a href="https://openreview.net/forum?id=B1hcZZ-AW">ICLR, 2018</a></p>
<div id="theis-et-al-2018"></div>
<p><strong>Lucas Theis, Iryna Korshunova, Alykhan Tejani and Ferenc Huszár</strong>. Faster gaze prediction with dense networks and Fisher pruning. <a href="https://arxiv.org/abs/1801.05787">arxiv:1801.05787</a></p>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="conditional_computation.html" class="btn btn-neutral float-right" title="Conditional Computation">Next <span class="icon icon-circle-arrow-right"></span></a>
<a href="quantization.html" class="btn btn-neutral" title="Quantization"><span class="icon icon-circle-arrow-left"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<!-- Copyright etc -->
</div>
Built with <a href="http://www.mkdocs.org">MkDocs</a> using a <a href="https://github.com/snide/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<div class="rst-versions" role="note" style="cursor: pointer">
<span class="rst-current-version" data-toggle="rst-current-version">
<span><a href="quantization.html" style="color: #fcfcfc;">« Previous</a></span>
<span style="margin-left: 15px"><a href="conditional_computation.html" style="color: #fcfcfc">Next »</a></span>
</span>
</div>
<script>var base_url = '.';</script>
<script src="js/theme.js" defer></script>
<script src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS_HTML" defer></script>
<script src="search/main.js" defer></script>
</body>
</html>