-
Notifications
You must be signed in to change notification settings - Fork 3.5k
/
Copy pathtest_pruning.py
345 lines (291 loc) · 13.7 KB
/
test_pruning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from collections import OrderedDict
from logging import INFO
from typing import Union
import pytest
import torch
import torch.nn.utils.prune as pytorch_prune
from torch import nn
from torch.nn import Sequential
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, ModelPruning
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
class TestModel(BoringModel):
    """Boring model whose ``layer`` is a named stack of three linear layers.

    ``mlp_2`` is built without a bias term. ``test_step`` is disabled by
    setting it to ``None``.
    """

    test_step = None

    def __init__(self):
        super().__init__()
        # Build the named layers one by one; insertion order defines the stack order.
        named_layers = OrderedDict()
        named_layers["mlp_1"] = nn.Linear(32, 32)
        named_layers["mlp_2"] = nn.Linear(32, 32, bias=False)
        named_layers["mlp_3"] = nn.Linear(32, 2)
        self.layer = Sequential(named_layers)

    def training_step(self, batch, batch_idx):
        """Log the negated batch index under the key ``"test"`` and defer to the parent step."""
        self.log("test", -batch_idx)
        return super().training_step(batch, batch_idx)
class TestPruningMethod(pytorch_prune.BasePruningMethod):
    """Custom pruning method that masks out every other entry of the flattened tensor."""

    PRUNING_TYPE = "unstructured"

    def compute_mask(self, _, default_mask):
        """Return a copy of ``default_mask`` with every even flat index set to zero.

        The first positional argument is ignored; ``default_mask`` itself is
        left untouched because the work is done on a clone.
        """
        pruned = default_mask.clone()
        flat_view = pruned.view(-1)
        # Zero positions 0, 2, 4, ... of the flattened view; the write lands in ``pruned``.
        flat_view[::2] = 0
        return pruned

    @classmethod
    def apply(cls, module, name, amount):
        """Forward to the parent ``apply``, passing ``amount`` as a keyword argument."""
        return super().apply(module, name, amount=amount)
def train_with_pruning_callback(
    tmpdir,
    parameters_to_prune=False,
    use_global_unstructured=False,
    pruning_fn="l1_unstructured",
    use_lottery_ticket_hypothesis=False,
    accelerator=None,
    gpus=None,
    num_processes=1,
):
    """Fit and test a ``TestModel`` with a ``ModelPruning`` callback built from the given options.

    Args:
        tmpdir: directory used as the trainer's ``default_root_dir``.
        parameters_to_prune: when truthy, prune the weights of ``mlp_1`` and ``mlp_2``
            explicitly; otherwise select parameters by name.
        use_global_unstructured: forwarded to ``ModelPruning``.
        pruning_fn: pruning function name or pruning method class.
        use_lottery_ticket_hypothesis: forwarded to ``ModelPruning``.
        accelerator: forwarded to ``Trainer``; when set, the final sparsity check is skipped.
        gpus: forwarded to ``Trainer``.
        num_processes: forwarded to ``Trainer``.

    For option combinations expected to be rejected, the helper only asserts that
    constructing the callback raises ``MisconfigurationException`` and returns early.
    """
    model = TestModel()

    # Freshly initialised weights are random, so none of them should be exactly zero yet.
    assert torch.all(model.layer.mlp_2.weight != 0)

    # The "structured" test is needed in several places below; evaluate it once.
    is_structured = isinstance(pruning_fn, str) and pruning_fn.endswith("_structured")

    pruning_kwargs = dict(
        pruning_fn=pruning_fn,
        amount=0.3,
        use_global_unstructured=use_global_unstructured,
        use_lottery_ticket_hypothesis=use_lottery_ticket_hypothesis,
        verbose=1,
    )
    if parameters_to_prune:
        pruning_kwargs["parameters_to_prune"] = [(model.layer.mlp_1, "weight"), (model.layer.mlp_2, "weight")]
    else:
        pruning_kwargs["parameter_names"] = ["weight"] if is_structured else ["weight", "bias"]
    if is_structured:
        pruning_kwargs["pruning_dim"] = 0
    if pruning_fn == "ln_structured":
        pruning_kwargs["pruning_norm"] = 1

    # Misconfiguration checks: pick the error message expected for an invalid combination, if any.
    expected_error = None
    if is_structured and use_global_unstructured:
        expected_error = "is supported with `use_global_unstructured=True`"
    elif ModelPruning._is_pruning_method(pruning_fn) and not use_global_unstructured:
        expected_error = "currently only supported with"

    if expected_error is not None:
        with pytest.raises(MisconfigurationException, match=expected_error):
            ModelPruning(**pruning_kwargs)
        return

    pruning_callback = ModelPruning(**pruning_kwargs)

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=pruning_callback,
        accelerator=accelerator,
        gpus=gpus,
        num_processes=num_processes,
        max_epochs=10,
        limit_train_batches=10,
        limit_val_batches=2,
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
        weights_summary=None,
    )
    trainer.fit(model)
    trainer.test(model)

    if accelerator:
        return

    # Check some have been pruned
    assert torch.any(model.layer.mlp_2.weight == 0)
def test_pruning_misconfiguration():
    """Each invalid ``ModelPruning`` argument set must raise ``MisconfigurationException`` with its message."""
    # (constructor kwargs, regex the error message must match), checked in this order.
    invalid_cases = (
        (
            {"pruning_fn": "l1_unstructured", "parameter_names": ["chocolate"]},
            r"chocolate isn't in \('weight', 'bias'\)",
        ),
        ({"pruning_fn": {}}, r"expected to be a str in \["),
        ({"pruning_fn": "random_structured"}, "should be provided"),
        ({"pruning_fn": "l1_unstructured", "verbose": 3}, r"must be any of \(0, 1, 2\)"),
        (
            {"pruning_fn": "ln_structured", "pruning_dim": 0},
            "requesting `ln_structured` pruning, the `pruning_norm`",
        ),
    )
    for bad_kwargs, message in invalid_cases:
        with pytest.raises(MisconfigurationException, match=message):
            ModelPruning(**bad_kwargs)
@pytest.mark.parametrize("parameters_to_prune", [False, True])
@pytest.mark.parametrize("use_global_unstructured", [False, True])
@pytest.mark.parametrize(
    "pruning_fn", ["l1_unstructured", "random_unstructured", "ln_structured", "random_structured", TestPruningMethod]
)
@pytest.mark.parametrize("use_lottery_ticket_hypothesis", [False, True])
def test_pruning_callback(
    tmpdir,
    use_global_unstructured: bool,
    parameters_to_prune: bool,
    pruning_fn: Union[str, pytorch_prune.BasePruningMethod],
    use_lottery_ticket_hypothesis: bool,
):
    """Run the pruning training helper for every combination of the parametrized options."""
    options = {
        "parameters_to_prune": parameters_to_prune,
        "use_global_unstructured": use_global_unstructured,
        "pruning_fn": pruning_fn,
        "use_lottery_ticket_hypothesis": use_lottery_ticket_hypothesis,
    }
    train_with_pruning_callback(tmpdir, **options)
@RunIf(special=True, min_gpus=2)
def test_pruning_callback_ddp_0(tmpdir):
    """DDP on 2 GPUs: parameters selected by name, per-parameter pruning."""
    train_with_pruning_callback(
        tmpdir,
        accelerator="ddp",
        gpus=2,
        parameters_to_prune=False,
        use_global_unstructured=False,
    )
@RunIf(special=True, min_gpus=2)
def test_pruning_callback_ddp_1(tmpdir):
    """DDP on 2 GPUs: parameters selected by name, global unstructured pruning."""
    train_with_pruning_callback(
        tmpdir,
        accelerator="ddp",
        gpus=2,
        parameters_to_prune=False,
        use_global_unstructured=True,
    )
@RunIf(special=True, min_gpus=2)
def test_pruning_callback_ddp_2(tmpdir):
    """DDP on 2 GPUs: explicit parameters to prune, per-parameter pruning."""
    train_with_pruning_callback(
        tmpdir,
        accelerator="ddp",
        gpus=2,
        parameters_to_prune=True,
        use_global_unstructured=False,
    )
@RunIf(special=True, min_gpus=2)
def test_pruning_callback_ddp_3(tmpdir):
    """DDP on 2 GPUs: explicit parameters to prune, global unstructured pruning."""
    train_with_pruning_callback(
        tmpdir,
        accelerator="ddp",
        gpus=2,
        parameters_to_prune=True,
        use_global_unstructured=True,
    )
@RunIf(min_gpus=2, skip_windows=True)
def test_pruning_callback_ddp_spawn(tmpdir):
    """``ddp_spawn`` on 2 GPUs with global unstructured pruning."""
    train_with_pruning_callback(
        tmpdir,
        accelerator="ddp_spawn",
        gpus=2,
        use_global_unstructured=True,
    )
@RunIf(skip_windows=True)
def test_pruning_callback_ddp_cpu(tmpdir):
    """``ddp_cpu`` with 2 processes and explicit parameters to prune."""
    train_with_pruning_callback(
        tmpdir,
        accelerator="ddp_cpu",
        num_processes=2,
        parameters_to_prune=True,
    )
@pytest.mark.parametrize("resample_parameters", (False, True))
def test_pruning_lth_callable(tmpdir, resample_parameters: bool):
    """Check that a callable ``use_lottery_ticket_hypothesis`` controls when the LTH step runs.

    The callable returns ``True`` for odd arguments only, and the test expects
    ``max_epochs // 2`` invocations of ``apply_lottery_ticket_hypothesis``.
    """
    model = TestModel()

    class ModelPruningTestCallback(ModelPruning):
        # Number of times the lottery-ticket step has run.
        lth_calls = 0

        def apply_lottery_ticket_hypothesis(self):
            super().apply_lottery_ticket_hypothesis()
            self.lth_calls += 1

            # Compare every tracked parameter against the stored copy of its original layer.
            for entry in self._original_layers.values():
                original_layer = entry["data"]
                for index, param_name in entry["names"]:
                    current_layer, current_name = self._parameters_to_prune[index]
                    assert param_name == current_name
                    current_data = getattr(current_layer, param_name).data
                    original_data = getattr(original_layer, param_name).data
                    unchanged = torch.allclose(current_data, original_data)
                    # With resampling the values must differ; without it they must match.
                    if self._resample_parameters:
                        assert not unchanged
                    else:
                        assert unchanged

    pruning = ModelPruningTestCallback(
        "l1_unstructured", use_lottery_ticket_hypothesis=lambda e: bool(e % 2), resample_parameters=resample_parameters
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=pruning,
        max_epochs=5,
        limit_train_batches=10,
        limit_val_batches=2,
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
        weights_summary=None,
    )
    trainer.fit(model)

    assert pruning.lth_calls == trainer.max_epochs // 2
@pytest.mark.parametrize("make_pruning_permanent", (False, True))
def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool):
    """Train with two pruning callbacks whose ``apply_pruning`` predicates alternate on parity.

    The "Applied ..." log lines are matched against the expected patterns, then a
    checkpoint is saved and loaded, and the presence of the ``weight_orig``
    attribute on the trained model is compared with ``make_pruning_permanent``.
    """
    model = TestModel()
    shared_kwargs = dict(
        parameters_to_prune=[(model.layer.mlp_1, "weight"), (model.layer.mlp_3, "weight")],
        verbose=2,
        make_pruning_permanent=make_pruning_permanent,
    )
    # ``p1`` fires when its argument is even, ``p2`` when it is odd.
    p1 = ModelPruning("l1_unstructured", amount=0.5, apply_pruning=lambda e: not e % 2, **shared_kwargs)
    p2 = ModelPruning("random_unstructured", amount=0.25, apply_pruning=lambda e: e % 2, **shared_kwargs)

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[p1, p2],
        max_epochs=3,
        limit_train_batches=10,
        limit_val_batches=2,
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
        weights_summary=None,
    )
    with caplog.at_level(INFO):
        trainer.fit(model)

    # Keep only the pruning report lines, with surrounding whitespace removed.
    applied_messages = [msg for msg in map(str.strip, caplog.messages) if msg.startswith("Applied")]

    percentage = r"\(\d+(?:\.\d+)?%\)"
    expected_patterns = [
        rf"Applied `L1Unstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
        rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.5. Pruned: 0 \(0.00%\) -> \d+ {percentage}",  # noqa: E501
        rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.5. Pruned: 0 \(0.00%\) -> \d+ {percentage}",  # noqa: E501
        rf"Applied `RandomUnstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
        rf"Applied `RandomUnstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.25. Pruned: \d+ {percentage} -> \d+ {percentage}",  # noqa: E501
        rf"Applied `RandomUnstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.25. Pruned: \d+ {percentage} -> \d+ {percentage}",  # noqa: E501
        rf"Applied `L1Unstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
        rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.5. Pruned: \d+ {percentage} -> \d+ {percentage}",  # noqa: E501
        rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.5. Pruned: \d+ {percentage} -> \d+ {percentage}",  # noqa: E501
    ]
    # NOTE(review): zip() stops at the shorter sequence, so missing log lines are not detected here.
    for message, pattern in zip(applied_messages, expected_patterns):
        assert re.compile(pattern).match(message)

    filepath = str(tmpdir / "foo.ckpt")
    trainer.save_checkpoint(filepath)

    # NOTE(review): the value returned by ``load_from_checkpoint`` is discarded, so the
    # attribute check below inspects the trained ``model`` - confirm this is intended.
    model.load_from_checkpoint(filepath, strict=False)
    has_pruning = hasattr(model.layer.mlp_1, "weight_orig")
    if make_pruning_permanent:
        assert not has_pruning
    else:
        assert has_pruning
@pytest.mark.parametrize("prune_on_train_epoch_end", (False, True))
@pytest.mark.parametrize("save_on_train_epoch_end", (False, True))
def test_permanent_when_model_is_saved_multiple_times(
    tmpdir, caplog, prune_on_train_epoch_end, save_on_train_epoch_end
):
    """Check permanent pruning when the model is checkpointed several times during training.

    With ``make_pruning_permanent=True`` the saved ``state_dict`` must not contain
    the ``weight_orig`` entry, while the module being trained must keep its
    ``weight_orig`` attribute across saves so training can continue with the
    same pruning buffers.

    Args:
        tmpdir: temporary directory used as the trainer's root directory.
        caplog: pytest log-capture fixture used to read the pruning reports.
        prune_on_train_epoch_end: forwarded to the pruning callback.
        save_on_train_epoch_end: forwarded to the checkpoint callback.
    """
    if prune_on_train_epoch_end and save_on_train_epoch_end:
        pytest.xfail(
            "Pruning sets the `grad_fn` of the parameters so we can't save"
            " right after as pruning has not been made permanent"
        )

    class TestPruning(ModelPruning):
        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            # Remember whether the live module was carrying pruning state before saving.
            had_buffers = hasattr(pl_module.layer.mlp_3, "weight_orig")
            super().on_save_checkpoint(trainer, pl_module, checkpoint)
            # The checkpoint must not contain the pruning re-parametrization entry ...
            assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"]
            # ... and saving must not have removed it from the module being trained.
            if had_buffers:
                assert hasattr(pl_module.layer.mlp_3, "weight_orig")

    model = TestModel()
    pruning_callback = TestPruning(
        "random_unstructured",
        parameters_to_prune=[(model.layer.mlp_3, "weight")],
        verbose=1,
        make_pruning_permanent=True,
        prune_on_train_epoch_end=prune_on_train_epoch_end,
    )
    ckpt_callback = ModelCheckpoint(
        monitor="test", save_top_k=2, save_last=True, save_on_train_epoch_end=save_on_train_epoch_end
    )
    # Root the trainer in ``tmpdir`` like the other tests in this file, so the run's
    # output stays inside the per-test temporary directory.
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[pruning_callback, ckpt_callback],
        max_epochs=3,
        enable_progress_bar=False,
    )
    with caplog.at_level(INFO):
        trainer.fit(model)

    actual = [m.strip() for m in caplog.messages]
    actual = [m for m in actual if m.startswith("Applied")]
    percentage = r"\(\d+(?:\.\d+)?%\)"
    expected = [
        rf"Applied `RandomUnstructured`. Pruned: \d+\/66 {percentage} -> \d+\/66 {percentage}",
        rf"Applied `RandomUnstructured`. Pruned: \d+\/66 {percentage} -> \d+\/66 {percentage}",
        rf"Applied `RandomUnstructured`. Pruned: \d+\/66 {percentage} -> \d+\/66 {percentage}",
    ]
    expected = [re.compile(s) for s in expected]
    # NOTE(review): zip() stops at the shorter sequence, so missing log lines are not detected here.
    assert all(regex.match(s) for s, regex in zip(actual, expected))

    # removed on_train_end
    assert not hasattr(model.layer.mlp_3, "weight_orig")

    # Inspect the models built from the checkpoints: the loader's return value is the
    # object that reflects the file contents, not the already-trained ``model``.
    checkpoint_paths = (
        trainer.checkpoint_callback.kth_best_model_path,
        trainer.checkpoint_callback.last_model_path,
    )
    for checkpoint_path in checkpoint_paths:
        reloaded = TestModel.load_from_checkpoint(checkpoint_path)
        assert not hasattr(reloaded.layer.mlp_3, "weight_orig")