diff --git a/dataset.py b/dataset.py index 7f4cd07..1a4b416 100644 --- a/dataset.py +++ b/dataset.py @@ -12,8 +12,10 @@ import glob import os + def loader(path, batch_size=32, num_workers=4, pin_memory=True): - normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) return data.DataLoader( datasets.ImageFolder(path, transforms.Compose([ @@ -28,8 +30,10 @@ def loader(path, batch_size=32, num_workers=4, pin_memory=True): num_workers=num_workers, pin_memory=pin_memory) + def test_loader(path, batch_size=32, num_workers=4, pin_memory=True): - normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) return data.DataLoader( datasets.ImageFolder(path, transforms.Compose([ diff --git a/finetune.py b/finetune.py index bfed08d..5434243 100644 --- a/finetune.py +++ b/finetune.py @@ -1,3 +1,4 @@ +import copy import torch from torch.autograd import Variable from torchvision import models @@ -15,256 +16,272 @@ from heapq import nsmallest import time + class ModifiedVGG16Model(torch.nn.Module): - def __init__(self): - super(ModifiedVGG16Model, self).__init__() - - model = models.vgg16(pretrained=True) - self.features = model.features - - for param in self.features.parameters(): - param.requires_grad = False - - self.classifier = nn.Sequential( - nn.Dropout(), - nn.Linear(25088, 4096), - nn.ReLU(inplace=True), - nn.Dropout(), - nn.Linear(4096, 4096), - nn.ReLU(inplace=True), - nn.Linear(4096, 2)) - - def forward(self, x): - x = self.features(x) - x = x.view(x.size(0), -1) - x = self.classifier(x) - return x + def __init__(self): + super(ModifiedVGG16Model, self).__init__() + + model = models.vgg16(pretrained=True) + self.features = model.features + + for param in self.features.parameters(): + param.requires_grad = False + + self.classifier = nn.Sequential( + nn.Dropout(), + nn.Linear(25088, 4096), + nn.ReLU(inplace=True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(inplace=True), + nn.Linear(4096, 2)) + + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0), -1) + x = self.classifier(x) + return x + class FilterPrunner: - def __init__(self, model): - self.model = model - self.reset() - - def reset(self): - # self.activations = [] - # self.gradients = [] - # self.grad_index = 0 - # self.activation_to_layer = {} - self.filter_ranks = {} - - def forward(self, x): - self.activations = [] - self.gradients = [] - self.grad_index = 0 - self.activation_to_layer = {} - - activation_index = 0 - for layer, (name, module) in enumerate(self.model.features._modules.items()): - x = module(x) - if isinstance(module, torch.nn.modules.conv.Conv2d): - x.register_hook(self.compute_rank) - self.activations.append(x) - self.activation_to_layer[activation_index] = layer - activation_index += 1 - - return self.model.classifier(x.view(x.size(0), -1)) - - def compute_rank(self, grad): - activation_index = len(self.activations) - self.grad_index - 1 - activation = self.activations[activation_index] - values = \ - torch.sum((activation * grad), dim = 0).\ - sum(dim=2).sum(dim=3)[0, :, 0, 0].data - - # Normalize the rank by the filter dimensions - values = \ - values / (activation.size(0) * activation.size(2) * activation.size(3)) - - if activation_index not in self.filter_ranks: - self.filter_ranks[activation_index] = \ - 
torch.FloatTensor(activation.size(1)).zero_().cuda() - - self.filter_ranks[activation_index] += values - self.grad_index += 1 - - def lowest_ranking_filters(self, num): - data = [] - for i in sorted(self.filter_ranks.keys()): - for j in range(self.filter_ranks[i].size(0)): - data.append((self.activation_to_layer[i], j, self.filter_ranks[i][j])) - - return nsmallest(num, data, itemgetter(2)) - - def normalize_ranks_per_layer(self): - for i in self.filter_ranks: - v = torch.abs(self.filter_ranks[i]) - v = v / np.sqrt(torch.sum(v * v)) - self.filter_ranks[i] = v.cpu() - - def get_prunning_plan(self, num_filters_to_prune): - filters_to_prune = self.lowest_ranking_filters(num_filters_to_prune) - - # After each of the k filters are prunned, - # the filter index of the next filters change since the model is smaller. - filters_to_prune_per_layer = {} - for (l, f, _) in filters_to_prune: - if l not in filters_to_prune_per_layer: - filters_to_prune_per_layer[l] = [] - filters_to_prune_per_layer[l].append(f) - - for l in filters_to_prune_per_layer: - filters_to_prune_per_layer[l] = sorted(filters_to_prune_per_layer[l]) - for i in range(len(filters_to_prune_per_layer[l])): - filters_to_prune_per_layer[l][i] = filters_to_prune_per_layer[l][i] - i - - filters_to_prune = [] - for l in filters_to_prune_per_layer: - for i in filters_to_prune_per_layer[l]: - filters_to_prune.append((l, i)) - - return filters_to_prune + def __init__(self, model): + self.model = model + self.reset() + + def reset(self): + # self.activations = [] + # self.gradients = [] + # self.grad_index = 0 + # self.activation_to_layer = {} + self.filter_ranks = {} + + def forward(self, x): + self.activations = [] + self.gradients = [] + self.grad_index = 0 + self.activation_to_layer = {} + + activation_index = 0 + for layer, (name, module) in enumerate(self.model.features._modules.items()): + x = module(x) + if isinstance(module, torch.nn.modules.conv.Conv2d): + x.register_hook(self.compute_rank(activation_index)) + self.activations.append(x) + self.activation_to_layer[activation_index] = layer + activation_index += 1 + + return self.model.classifier(x.view(x.size(0), -1)) + + def compute_rank(self, activation_index): + # Returns a partial function + # as the callback function + def hook(grad): + activation = self.activations[activation_index] + values = \ + torch.sum((activation * grad), dim=0).\ + sum(dim=2).sum(dim=3)[0, :, 0, 0].data + + # Normalize the rank by the filter dimensions + values = \ + values / (activation.size(0) * activation.size(2) + * activation.size(3)) + + if activation_index not in self.filter_ranks: + self.filter_ranks[activation_index] = \ + torch.FloatTensor(activation.size(1)).zero_().cuda() + + self.filter_ranks[activation_index] += values + self.grad_index += 1 + return hook + + def lowest_ranking_filters(self, num): + data = [] + for i in sorted(self.filter_ranks.keys()): + for j in range(self.filter_ranks[i].size(0)): + data.append( + (self.activation_to_layer[i], j, self.filter_ranks[i][j])) + + return nsmallest(num, data, itemgetter(2)) + + def normalize_ranks_per_layer(self): + for i in self.filter_ranks: + v = torch.abs(self.filter_ranks[i]) + v = v / np.sqrt(torch.sum(v * v)) + self.filter_ranks[i] = v.cpu() + + def get_prunning_plan(self, num_filters_to_prune): + filters_to_prune = self.lowest_ranking_filters(num_filters_to_prune) + + # After each of the k filters are prunned, + # the filter index of the next filters change since the model is smaller. 
+ filters_to_prune_per_layer = {} + for (l, f, _) in filters_to_prune: + if l not in filters_to_prune_per_layer: + filters_to_prune_per_layer[l] = [] + filters_to_prune_per_layer[l].append(f) + + for l in filters_to_prune_per_layer: + filters_to_prune_per_layer[l] = sorted( + filters_to_prune_per_layer[l]) + for i in range(len(filters_to_prune_per_layer[l])): + filters_to_prune_per_layer[l][i] = filters_to_prune_per_layer[l][i] - i + + filters_to_prune = [] + for l in filters_to_prune_per_layer: + for i in filters_to_prune_per_layer[l]: + filters_to_prune.append((l, i)) + + return filters_to_prune + class PrunningFineTuner_VGG16: - def __init__(self, train_path, test_path, model): - self.train_data_loader = dataset.loader(train_path) - self.test_data_loader = dataset.test_loader(test_path) - - self.model = model - self.criterion = torch.nn.CrossEntropyLoss() - self.prunner = FilterPrunner(self.model) - self.model.train() - - def test(self): - self.model.eval() - correct = 0 - total = 0 - - for i, (batch, label) in enumerate(self.test_data_loader): - batch = batch.cuda() - output = model(Variable(batch)) - pred = output.data.max(1)[1] - correct += pred.cpu().eq(label).sum() - total += label.size(0) - - print "Accuracy :", float(correct) / total - - self.model.train() - - def train(self, optimizer = None, epoches = 10): - if optimizer is None: - optimizer = \ - optim.SGD(model.classifier.parameters(), - lr=0.0001, momentum=0.9) - - for i in range(epoches): - print "Epoch: ", i - self.train_epoch(optimizer) - self.test() - print "Finished fine tuning." - - - def train_batch(self, optimizer, batch, label, rank_filters): - self.model.zero_grad() - input = Variable(batch) - - if rank_filters: - output = self.prunner.forward(input) - self.criterion(output, Variable(label)).backward() - else: - self.criterion(self.model(input), Variable(label)).backward() - optimizer.step() - - def train_epoch(self, optimizer = None, rank_filters = False): - for batch, label in self.train_data_loader: - self.train_batch(optimizer, batch.cuda(), label.cuda(), rank_filters) - - def get_candidates_to_prune(self, num_filters_to_prune): - self.prunner.reset() - - self.train_epoch(rank_filters = True) - - self.prunner.normalize_ranks_per_layer() - - return self.prunner.get_prunning_plan(num_filters_to_prune) - - def total_num_filters(self): - filters = 0 - for name, module in self.model.features._modules.items(): - if isinstance(module, torch.nn.modules.conv.Conv2d): - filters = filters + module.out_channels - return filters - - def prune(self): - #Get the accuracy before prunning - self.test() - - self.model.train() - - #Make sure all the layers are trainable - for param in self.model.features.parameters(): - param.requires_grad = True - - number_of_filters = self.total_num_filters() - num_filters_to_prune_per_iteration = 512 - iterations = int(float(number_of_filters) / num_filters_to_prune_per_iteration) - - iterations = int(iterations * 2.0 / 3) - - print "Number of prunning iterations to reduce 67% filters", iterations - - for _ in range(iterations): - print "Ranking filters.. " - prune_targets = self.get_candidates_to_prune(num_filters_to_prune_per_iteration) - layers_prunned = {} - for layer_index, filter_index in prune_targets: - if layer_index not in layers_prunned: - layers_prunned[layer_index] = 0 - layers_prunned[layer_index] = layers_prunned[layer_index] + 1 - - print "Layers that will be prunned", layers_prunned - print "Prunning filters.. 
" - model = self.model.cpu() - for layer_index, filter_index in prune_targets: - model = prune_vgg16_conv_layer(model, layer_index, filter_index) - - self.model = model.cuda() - - message = str(100*float(self.total_num_filters()) / number_of_filters) + "%" - print "Filters prunned", str(message) - self.test() - print "Fine tuning to recover from prunning iteration." - optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9) - self.train(optimizer, epoches = 10) - - - print "Finished. Going to fine tune the model a bit more" - self.train(optimizer, epoches = 15) - torch.save(model, "model_prunned") + def __init__(self, train_path, test_path, model): + self.train_data_loader = dataset.loader(train_path) + self.test_data_loader = dataset.test_loader(test_path) + + self.model = model + self.criterion = torch.nn.CrossEntropyLoss() + self.prunner = FilterPrunner(self.model) + self.model.train() + + def test(self): + self.model.eval() + correct = 0 + total = 0 + + for i, (batch, label) in enumerate(self.test_data_loader): + batch = batch.cuda() + output = model(Variable(batch)) + pred = output.data.max(1)[1] + correct += pred.cpu().eq(label).sum() + total += label.size(0) + + print("Accuracy :" + float(correct) / total) + + self.model.train() + + def train(self, optimizer=None, epoches=10): + if optimizer is None: + optimizer = \ + optim.SGD(model.classifier.parameters(), + lr=0.0001, momentum=0.9) + + for i in range(epoches): + print("Epoch: ", i) + self.train_epoch(optimizer) + self.test() + print("Finished fine tuning.") + + def train_batch(self, optimizer, batch, label, rank_filters): + self.model.zero_grad() + input = Variable(batch) + + if rank_filters: + output = self.prunner.forward(input) + self.criterion(output, Variable(label)).backward() + else: + self.criterion(self.model(input), Variable(label)).backward() + optimizer.step() + + def train_epoch(self, optimizer=None, rank_filters=False): + for batch, label in self.train_data_loader: + self.train_batch(optimizer, batch.cuda(), + label.cuda(), rank_filters) + + def get_candidates_to_prune(self, num_filters_to_prune): + self.prunner.reset() + + self.train_epoch(rank_filters=True) + + self.prunner.normalize_ranks_per_layer() + + return self.prunner.get_prunning_plan(num_filters_to_prune) + + def total_num_filters(self): + filters = 0 + for name, module in self.model.features._modules.items(): + if isinstance(module, torch.nn.modules.conv.Conv2d): + filters = filters + module.out_channels + return filters + + def prune(self): + # Get the accuracy before prunning + self.test() + + self.model.train() + + # Make sure all the layers are trainable + for param in self.model.features.parameters(): + param.requires_grad = True + + number_of_filters = self.total_num_filters() + num_filters_to_prune_per_iteration = 512 + iterations = int(float(number_of_filters) / + num_filters_to_prune_per_iteration) + + iterations = int(iterations * 2.0 / 3) + + print("Number of prunning iterations to reduce 67% filters", iterations) + + for _ in range(iterations): + print("Ranking filters.. ") + prune_targets = self.get_candidates_to_prune( + num_filters_to_prune_per_iteration) + layers_prunned = {} + for layer_index, filter_index in prune_targets: + if layer_index not in layers_prunned: + layers_prunned[layer_index] = 0 + layers_prunned[layer_index] = layers_prunned[layer_index] + 1 + + print("Layers that will be prunned", layers_prunned) + print("Prunning filters.. 
") + model = self.model.cpu() + for layer_index, filter_index in prune_targets: + model = prune_vgg16_conv_layer( + model, layer_index, filter_index) + + self.model = model.cuda() + + message = str(100 * float(self.total_num_filters()) / + number_of_filters) + "%" + print("Filters prunned", str(message)) + self.test() + print("Fine tuning to recover from prunning iteration.") + optimizer = optim.SGD(self.model.parameters(), + lr=0.001, momentum=0.9) + self.train(optimizer, epoches=10) + + print("Finished. Going to fine tune the model a bit more") + self.train(optimizer, epoches=15) + torch.save(model, "model_prunned") + def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--train", dest="train", action="store_true") parser.add_argument("--prune", dest="prune", action="store_true") - parser.add_argument("--train_path", type = str, default = "train") - parser.add_argument("--test_path", type = str, default = "test") + parser.add_argument("--train_path", type=str, default="train") + parser.add_argument("--test_path", type=str, default="test") parser.set_defaults(train=False) parser.set_defaults(prune=False) args = parser.parse_args() return args + if __name__ == '__main__': - args = get_args() + args = get_args() - if args.train: - model = ModifiedVGG16Model().cuda() - elif args.prune: - model = torch.load("model").cuda() + if args.train: + model = ModifiedVGG16Model().cuda() + elif args.prune: + model = torch.load("model").cuda() - fine_tuner = PrunningFineTuner_VGG16(args.train_path, args.test_path, model) + fine_tuner = PrunningFineTuner_VGG16( + args.train_path, args.test_path, model) - if args.train: - fine_tuner.train(epoches = 20) - torch.save(model, "model") + if args.train: + fine_tuner.train(epoches=20) + torch.save(model, "model") - elif args.prune: - fine_tuner.prune() \ No newline at end of file + elif args.prune: + fine_tuner.prune() \ No newline at end of file diff --git a/prune.py b/prune.py index a605388..07952a8 100644 --- a/prune.py +++ b/prune.py @@ -4,125 +4,131 @@ import cv2 import sys import numpy as np - + + def replace_layers(model, i, indexes, layers): - if i in indexes: - return layers[indexes.index(i)] - return model[i] + if i in indexes: + return layers[indexes.index(i)] + return model[i] + def prune_vgg16_conv_layer(model, layer_index, filter_index): - _, conv = model.features._modules.items()[layer_index] - next_conv = None - offset = 1 - - while layer_index + offset < len(model.features._modules.items()): - res = model.features._modules.items()[layer_index+offset] - if isinstance(res[1], torch.nn.modules.conv.Conv2d): - next_name, next_conv = res - break - offset = offset + 1 - - new_conv = \ - torch.nn.Conv2d(in_channels = conv.in_channels, \ - out_channels = conv.out_channels - 1, - kernel_size = conv.kernel_size, \ - stride = conv.stride, - padding = conv.padding, - dilation = conv.dilation, - groups = conv.groups, - bias = conv.bias) - - old_weights = conv.weight.data.cpu().numpy() - new_weights = new_conv.weight.data.cpu().numpy() - - new_weights[: filter_index, :, :, :] = old_weights[: filter_index, :, :, :] - new_weights[filter_index : , :, :, :] = old_weights[filter_index + 1 :, :, :, :] - new_conv.weight.data = torch.from_numpy(new_weights).cuda() - - bias_numpy = conv.bias.data.cpu().numpy() - - bias = np.zeros(shape = (bias_numpy.shape[0] - 1), dtype = np.float32) - bias[:filter_index] = bias_numpy[:filter_index] - bias[filter_index : ] = bias_numpy[filter_index + 1 :] - new_conv.bias.data = torch.from_numpy(bias).cuda() - - if not 
next_conv is None: - next_new_conv = \ - torch.nn.Conv2d(in_channels = next_conv.in_channels - 1,\ - out_channels = next_conv.out_channels, \ - kernel_size = next_conv.kernel_size, \ - stride = next_conv.stride, - padding = next_conv.padding, - dilation = next_conv.dilation, - groups = next_conv.groups, - bias = next_conv.bias) - - old_weights = next_conv.weight.data.cpu().numpy() - new_weights = next_new_conv.weight.data.cpu().numpy() - - new_weights[:, : filter_index, :, :] = old_weights[:, : filter_index, :, :] - new_weights[:, filter_index : , :, :] = old_weights[:, filter_index + 1 :, :, :] - next_new_conv.weight.data = torch.from_numpy(new_weights).cuda() - - next_new_conv.bias.data = next_conv.bias.data - - if not next_conv is None: - features = torch.nn.Sequential( - *(replace_layers(model.features, i, [layer_index, layer_index+offset], \ - [new_conv, next_new_conv]) for i, _ in enumerate(model.features))) - del model.features - del conv - - model.features = features - - else: - #Prunning the last conv layer. This affects the first linear layer of the classifier. - model.features = torch.nn.Sequential( - *(replace_layers(model.features, i, [layer_index], \ - [new_conv]) for i, _ in enumerate(model.features))) - layer_index = 0 - old_linear_layer = None - for _, module in model.classifier._modules.items(): - if isinstance(module, torch.nn.Linear): - old_linear_layer = module - break - layer_index = layer_index + 1 - - if old_linear_layer is None: - raise BaseException("No linear laye found in classifier") - params_per_input_channel = old_linear_layer.in_features / conv.out_channels - - new_linear_layer = \ - torch.nn.Linear(old_linear_layer.in_features - params_per_input_channel, - old_linear_layer.out_features) - - old_weights = old_linear_layer.weight.data.cpu().numpy() - new_weights = new_linear_layer.weight.data.cpu().numpy() - - new_weights[:, : filter_index * params_per_input_channel] = \ - old_weights[:, : filter_index * params_per_input_channel] - new_weights[:, filter_index * params_per_input_channel :] = \ - old_weights[:, (filter_index + 1) * params_per_input_channel :] - - new_linear_layer.bias.data = old_linear_layer.bias.data - - new_linear_layer.weight.data = torch.from_numpy(new_weights).cuda() - - classifier = torch.nn.Sequential( - *(replace_layers(model.classifier, i, [layer_index], \ - [new_linear_layer]) for i, _ in enumerate(model.classifier))) - - del model.classifier - del next_conv - del conv - model.classifier = classifier - - return model + _, conv = model.features._modules.items()[layer_index] + next_conv = None + offset = 1 + + while layer_index + offset < len(model.features._modules.items()): + res = model.features._modules.items()[layer_index + offset] + if isinstance(res[1], torch.nn.modules.conv.Conv2d): + next_name, next_conv = res + break + offset = offset + 1 + + new_conv = \ + torch.nn.Conv2d(in_channels=conv.in_channels, + out_channels=conv.out_channels - 1, + kernel_size=conv.kernel_size, + stride=conv.stride, + padding=conv.padding, + dilation=conv.dilation, + groups=conv.groups, + bias=conv.bias) + + old_weights = conv.weight.data.cpu().numpy() + new_weights = new_conv.weight.data.cpu().numpy() + + new_weights[: filter_index, :, :, :] = old_weights[: filter_index, :, :, :] + new_weights[filter_index:, :, :, + :] = old_weights[filter_index + 1:, :, :, :] + new_conv.weight.data = torch.from_numpy(new_weights).cuda() + + bias_numpy = conv.bias.data.cpu().numpy() + + bias = np.zeros(shape=(bias_numpy.shape[0] - 1), dtype=np.float32) + 
bias[:filter_index] = bias_numpy[:filter_index] + bias[filter_index:] = bias_numpy[filter_index + 1:] + new_conv.bias.data = torch.from_numpy(bias).cuda() + + if not next_conv is None: + next_new_conv = \ + torch.nn.Conv2d(in_channels=next_conv.in_channels - 1, + out_channels=next_conv.out_channels, + kernel_size=next_conv.kernel_size, + stride=next_conv.stride, + padding=next_conv.padding, + dilation=next_conv.dilation, + groups=next_conv.groups, + bias=next_conv.bias) + + old_weights = next_conv.weight.data.cpu().numpy() + new_weights = next_new_conv.weight.data.cpu().numpy() + + new_weights[:, : filter_index, :, + :] = old_weights[:, : filter_index, :, :] + new_weights[:, filter_index:, :, + :] = old_weights[:, filter_index + 1:, :, :] + next_new_conv.weight.data = torch.from_numpy(new_weights).cuda() + + next_new_conv.bias.data = next_conv.bias.data + + if not next_conv is None: + features = torch.nn.Sequential( + *(replace_layers(model.features, i, [layer_index, layer_index + offset], + [new_conv, next_new_conv]) for i, _ in enumerate(model.features))) + del model.features + del conv + + model.features = features + + else: + # Prunning the last conv layer. This affects the first linear layer of the classifier. + model.features = torch.nn.Sequential( + *(replace_layers(model.features, i, [layer_index], + [new_conv]) for i, _ in enumerate(model.features))) + layer_index = 0 + old_linear_layer = None + for _, module in model.classifier._modules.items(): + if isinstance(module, torch.nn.Linear): + old_linear_layer = module + break + layer_index = layer_index + 1 + + if old_linear_layer is None: + raise BaseException("No linear laye found in classifier") + params_per_input_channel = old_linear_layer.in_features / conv.out_channels + + new_linear_layer = \ + torch.nn.Linear(old_linear_layer.in_features - params_per_input_channel, + old_linear_layer.out_features) + + old_weights = old_linear_layer.weight.data.cpu().numpy() + new_weights = new_linear_layer.weight.data.cpu().numpy() + + new_weights[:, : filter_index * params_per_input_channel] = \ + old_weights[:, : filter_index * params_per_input_channel] + new_weights[:, filter_index * params_per_input_channel:] = \ + old_weights[:, (filter_index + 1) * params_per_input_channel:] + + new_linear_layer.bias.data = old_linear_layer.bias.data + + new_linear_layer.weight.data = torch.from_numpy(new_weights).cuda() + + classifier = torch.nn.Sequential( + *(replace_layers(model.classifier, i, [layer_index], + [new_linear_layer]) for i, _ in enumerate(model.classifier))) + + del model.classifier + del next_conv + del conv + model.classifier = classifier + + return model + if __name__ == '__main__': - model = models.vgg16(pretrained=True) - model.train() + model = models.vgg16(pretrained=True) + model.train() - t0 = time.time() - model = prune_conv_layer(model, 28, 10) - print "The prunning took", time.time() - t0 \ No newline at end of file + t0 = time.time() + model = prune_conv_layer(model, 28, 10) + print("The prunning took ", time.time() - t0) \ No newline at end of file
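
Below is a minimal, self-contained sketch of the closure-based gradient-hook pattern that the compute_rank change in finetune.py introduces. It is illustrative only: the toy two-conv model, the module-level activations/filter_ranks dicts, and make_rank_hook are invented for the example, and it runs on CPU rather than CUDA. Each Conv2d output registers its own hook, and the closure remembers which activation the arriving gradient belongs to:

import torch
import torch.nn as nn

activations = {}   # activation_index -> conv output tensor saved in forward
filter_ranks = {}  # activation_index -> per-filter rank accumulator


def make_rank_hook(activation_index):
    # The closure captures which conv output this gradient belongs to,
    # so no shared grad_index counter is needed.
    def hook(grad):
        activation = activations[activation_index].detach()
        # Taylor-style rank: sum of (activation * gradient) per filter,
        # normalized by batch size and spatial dimensions.
        values = (activation * grad).sum(dim=(0, 2, 3)) / (
            activation.size(0) * activation.size(2) * activation.size(3))
        filter_ranks[activation_index] = filter_ranks.get(
            activation_index, torch.zeros(activation.size(1))) + values
    return hook


# Toy two-conv model, just to exercise the hooks.
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                      nn.Conv2d(8, 4, 3, padding=1), nn.ReLU())

x = torch.randn(2, 3, 16, 16)
activation_index = 0
for module in model:
    x = module(x)
    if isinstance(module, nn.Conv2d):
        x.register_hook(make_rank_hook(activation_index))
        activations[activation_index] = x
        activation_index += 1

x.sum().backward()
print({i: tuple(r.shape) for i, r in filter_ranks.items()})  # {0: (8,), 1: (4,)}

Capturing activation_index in the closure removes the old reliance on the shared self.grad_index counter, which assumed that gradient hooks always fire in exact reverse order of the forward pass.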