From 5b8de726524d9b77970bd9c37bb1a00bee69f7a9 Mon Sep 17 00:00:00 2001 From: Ryuk Date: Wed, 23 Feb 2022 22:45:58 +0800 Subject: [PATCH] fix Issues#44 and add support for dropout layer --- dump_percepnet.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/dump_percepnet.py b/dump_percepnet.py index e01715b..7e90675 100755 --- a/dump_percepnet.py +++ b/dump_percepnet.py @@ -26,7 +26,7 @@ import torch import sys import rnn_train -from torch.nn import Sequential, GRU, Conv1d, Linear +from torch.nn import Sequential, GRU, Conv1d, Linear, Dropout import numpy as np def printVector(f, vector, name, dtype='float'): @@ -53,7 +53,7 @@ def dump_sequential_module(self, f, name): self[0].dump_data(f,name,activation) Sequential.dump_data = dump_sequential_module -def dump_linear_module(self, f, name, activation): +def dump_linear_module(self, f, name, activation='LINEAR'): print("printing layer " + name) weight = self.weight bias = self.bias @@ -144,7 +144,8 @@ def dump_conv1d_module(self, f, name, activation): f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "nnet.h"\n#include "nnet_data.h"\n\n') for name, module in model.named_children(): - module.dump_data(f, name) + if "dropout" not in name: + module.dump_data(f, name) f.write('extern const RNNModel percepnet_model_orig = {\n') for name, module in model.named_children(): @@ -152,4 +153,4 @@ def dump_conv1d_module(self, f, name, activation): f.write('};\n') f.close() - print("done") \ No newline at end of file + print("done")