Update layer norm options #191

Merged (7 commits) on Jun 17, 2021
35 changes: 15 additions & 20 deletions sru/modules.py
@@ -350,11 +350,13 @@ def extra_repr(self):
s += ", activation={activation}"
if self.v1:
s += ", v1={v1}"
s += ", rescale={rescale}"
if self.rescale:
s += ", rescale={rescale}"
if not self.has_skip_term:
s += ", has_skip_term={has_skip_term}"
if self.layer_norm:
s += ", layer_norm=True"
s += ", normalize_after={normalize_after}"
s += ",\n transform_module=" + str(self.transform_module)
return s.format(**self.__dict__)
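
The conditional pieces assembled above are resolved against the instance attributes by `str.format`. A tiny standalone sketch of that pattern, with a made-up class and attributes (not from this PR):

```python
class Example:
    """Illustrative only: mimics the extra_repr pattern used above."""

    def __init__(self) -> None:
        self.layer_norm = True
        self.normalize_after = False

    def extra_repr(self) -> str:
        s = ""
        if self.layer_norm:
            s += "layer_norm=True"
            s += ", normalize_after={normalize_after}"
        # placeholders are filled from the instance's attributes
        return s.format(**self.__dict__)


print(Example().extra_repr())  # -> layer_norm=True, normalize_after=False
```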

@@ -958,7 +960,8 @@ def __init__(self,
num_heads: int = 1,
bidirectional: bool = False,
layer_norm: bool = False,
normalization_type: int = 1,
normalize_after: bool = False,
attn_layer_norm: bool = True,
highway_bias: float = -2.0,
attention_every_n_layers: int = 1,
attention_last_n_layers: int = -1,
@@ -988,10 +991,12 @@ def __init__(self,
if True, use bidirectional SRU++ (default=False).
layer_norm: bool, optional
whether to apply layer normalization to each SRU++ layer (default=False).
normalization_type: int, optional
which type of layer normalization to apply. 1: apply normalization after attention.
2: apply normalization before any operators. 3: apply normalization after all operators.
(default=1)
normalize_after: bool, optional
whether to apply post layer norm that normalizes the output of each SRU++ layer
(default=False).
attn_layer_norm: bool, optional
whether to apply layer norm in the attention module or projected linear module if
attention is disabled (default=True).
highway_bias: float, optional
the initial value of the bias used in the highway (sigmoid) gate (default=-1.0).
attention_every_n_layers: int, optional
@@ -1007,10 +1012,6 @@
(default=1.0)

"""
if normalization_type > 3 or normalization_type < 1:
raise ValueError("normalization_type={} but expect 1, 2 or 3.".format(
normalization_type
))
if attention_every_n_layers != 1 and attention_last_n_layers != -1:
raise ValueError("Cannot set both attention_every_n_layers and "
"attention_last_n_layers in SRU++ module.")
@@ -1025,7 +1026,6 @@
self.rnn_lst = nn.ModuleList()
self.bidirectional = bidirectional
self.use_layer_norm = layer_norm
self.normalization_type = normalization_type
self.num_directions = 2 if bidirectional else 1
self.nn_rnn_compatible_return = nn_rnn_compatible_return
self.input_to_hidden: Optional[nn.Module] = None
@@ -1036,11 +1036,6 @@
else:
first_layer_input_size = input_size

# layer norm configuration
module_layer_norm = normalization_type == 1
cell_layer_norm = normalization_type != 1
cell_normalize_after = normalization_type == 3

# attention configuration
if attention_last_n_layers != -1:
use_attention = lambda ind: num_layers - ind <= attention_last_n_layers # noqa
@@ -1061,23 +1056,23 @@
dropout=dropout,
attn_dropout=attn_dropout,
num_heads=num_heads,
layer_norm=module_layer_norm,
layer_norm=attn_layer_norm,
)
else:
custom_m = SRUppProjectedLinear(
in_features,
out_features,
proj_features,
dropout=dropout,
layer_norm=module_layer_norm,
layer_norm=attn_layer_norm,
)
layer = SRUppCell(
in_features,
self.hidden_size,
dropout=dropout if i + 1 != num_layers else 0,
bidirectional=bidirectional,
layer_norm=cell_layer_norm,
normalize_after=cell_normalize_after,
layer_norm=layer_norm,
normalize_after=normalize_after,
highway_bias=highway_bias,
rescale=rescale,
transform_module=custom_m,
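
For readers migrating from the removed `normalization_type` argument: based only on the mapping lines deleted above (`module_layer_norm`, `cell_layer_norm`, `cell_normalize_after`), the old values appear to correspond to the new flags roughly as sketched below. How the old top-level `layer_norm` flag interacted with `normalization_type` is not visible in this diff, so treat this as an approximation rather than an official migration guide.

```python
import sru

# Approximate correspondence inferred from the deleted mapping:
#   normalization_type=1  ->  attn_layer_norm=True,  layer_norm=False
#   normalization_type=2  ->  attn_layer_norm=False, layer_norm=True, normalize_after=False
#   normalization_type=3  ->  attn_layer_norm=False, layer_norm=True, normalize_after=True
#
# For example, the old normalization_type=3 setup would now be written as:
model = sru.SRUpp(
    128, 128, 64,            # illustrative sizes
    num_layers=2,
    attn_layer_norm=False,
    layer_norm=True,
    normalize_after=True,
)
```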
2 changes: 1 addition & 1 deletion sru/version.py
@@ -1 +1 @@
__version__ = '3.0.0.dev3'
__version__ = '3.0.0.dev6'
16 changes: 11 additions & 5 deletions test/sru/test_sru.py
@@ -218,7 +218,8 @@ def test_srupp_creation(attn_every_n_layers, expected_transform_module):
@pytest.mark.parametrize("compat", [False, True])
@pytest.mark.parametrize("bidirectional", [False, True])
@pytest.mark.parametrize("layer_norm", [False, True])
def test_srupp(cuda, with_grad, compat, bidirectional, layer_norm):
@pytest.mark.parametrize("normalize_after", [False, True])
def test_srupp(cuda, with_grad, compat, bidirectional, layer_norm, normalize_after):
torch.manual_seed(123)
if cuda:
torch.backends.cudnn.deterministic = True
@@ -239,6 +240,7 @@ def run():
layers,
bidirectional=bidirectional,
layer_norm=layer_norm,
normalize_after=normalize_after,
nn_rnn_compatible_return=compat,
)
words_embeddings = torch.rand(
@@ -293,7 +295,8 @@ def cell_to_emb(cell, batch_size):
)
@pytest.mark.parametrize("bidirectional", [False, True])
@pytest.mark.parametrize("layer_norm", [False, True])
def test_srupp_backward_simple(cuda, bidirectional, layer_norm):
@pytest.mark.parametrize("normalize_after", [False, True])
def test_srupp_backward_simple(cuda, bidirectional, layer_norm, normalize_after):
torch.manual_seed(123)
if cuda:
torch.backends.cudnn.deterministic = True
@@ -306,7 +309,8 @@ def test_srupp_backward_simple(cuda, bidirectional, layer_norm):
proj_size = 2
encoder = sru.SRUpp(input_size, hidden_size, proj_size,
bidirectional=bidirectional,
layer_norm=layer_norm)
layer_norm=layer_norm,
normalize_after=normalize_after)
if cuda:
encoder = encoder.cuda()

@@ -324,7 +328,8 @@ def run(x):
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize("bidirectional", [False, True])
@pytest.mark.parametrize("layer_norm", [False, True])
def test_srupp_backward(bidirectional, layer_norm):
@pytest.mark.parametrize("normalize_after", [False, True])
def test_srupp_backward(bidirectional, layer_norm, normalize_after):
eps = 1e-4
torch.manual_seed(123)
torch.backends.cudnn.deterministic = True
@@ -337,7 +342,8 @@ def test_srupp_backward(bidirectional, layer_norm):
proj_size = 2
encoder = sru.SRUpp(input_size, hidden_size, proj_size,
bidirectional=bidirectional,
layer_norm=layer_norm)
layer_norm=layer_norm,
normalize_after=normalize_after)
x = torch.randn(input_length, batch_size, input_size)

# backward in CPU mode
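
The tests above gain a stacked `@pytest.mark.parametrize("normalize_after", [False, True])` decorator; stacked parametrize decorators combine as a cross product, so each of these tests now runs twice as many cases. A standalone sketch of the pattern (the test name and sizes are made up):

```python
import pytest
import torch
import sru


# Stacked decorators generate the cross product: 2 x 2 = 4 test cases here.
@pytest.mark.parametrize("layer_norm", [False, True])
@pytest.mark.parametrize("normalize_after", [False, True])
def test_srupp_layer_norm_options(layer_norm, normalize_after):
    model = sru.SRUpp(4, 4, 2,
                      layer_norm=layer_norm,
                      normalize_after=normalize_after)
    x = torch.randn(5, 3, 4)  # (length, batch, input_size)
    model(x)  # smoke test: the forward pass should run without raising
```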
6 changes: 3 additions & 3 deletions test/sru/test_torchscript.py
@@ -66,9 +66,9 @@ def test_sru(cuda, bidirectional, rescale, proj, layer_norm):
)
@pytest.mark.parametrize("bidirectional", [False, True])
@pytest.mark.parametrize("layer_norm", [False, True])
@pytest.mark.parametrize("normalization_type", [1, 2, 3])
@pytest.mark.parametrize("normalize_after", [False, True])
@pytest.mark.parametrize("attn_every_n_layers", [1, 2])
def test_srupp(cuda, bidirectional, layer_norm, normalization_type, attn_every_n_layers):
def test_srupp(cuda, bidirectional, layer_norm, normalize_after, attn_every_n_layers):
eps = 1e-4
torch.manual_seed(1234)
L = 5
@@ -79,7 +79,7 @@ def test_srupp(cuda, bidirectional, layer_norm, normalization_type, attn_every_n
model = sru.SRUpp(D, D, proj,
bidirectional=bidirectional,
layer_norm=layer_norm,
normalization_type=normalization_type,
normalize_after=normalize_after,
attention_every_n_layers=attn_every_n_layers)
if cuda:
model = model.cuda()
6 changes: 1 addition & 5 deletions test/test.sh
@@ -29,10 +29,6 @@ python test/test_ts_srupp.py > py_srupp_out.txt
sru/csrc/build/example_app srupp_ts.pt > cpp_srupp_out.txt
diff cpp_srupp_out.txt py_srupp_out.txt

python test/test_ts_srupp.py --normalization_type 2 > py_srupp_out.txt
sru/csrc/build/example_app srupp_ts.pt > cpp_srupp_out.txt
diff cpp_srupp_out.txt py_srupp_out.txt

python test/test_ts_srupp.py --normalization_type 3 > py_srupp_out.txt
python test/test_ts_srupp.py --normalize-after > py_srupp_out.txt
sru/csrc/build/example_app srupp_ts.pt > cpp_srupp_out.txt
diff cpp_srupp_out.txt py_srupp_out.txt
4 changes: 2 additions & 2 deletions test/test_ts_srupp.py
@@ -5,7 +5,7 @@

def run(args):
D = 4
model = sru.SRUpp(D, D, D, num_layers=2, normalization_type=args.normalization_type)
model = sru.SRUpp(D, D, D, num_layers=2, normalize_after=args.normalize_after)
model.eval()

ts_model = torch.jit.script(model)
@@ -21,6 +21,6 @@ def run(args):

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--normalization_type', type=int, default=1)
parser.add_argument('--normalize-after', action='store_true')
args = parser.parse_args()
run(args)
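
Together with the test.sh change above, the updated script exercises TorchScript export with the new flag. A rough end-to-end sketch of that flow; the output file name srupp_ts.pt and the example_app consumer come from test/test.sh, and the explicit save step is assumed rather than shown in this diff:

```python
import torch
import sru

# Equivalent of running: python test/test_ts_srupp.py --normalize-after
model = sru.SRUpp(4, 4, 4, num_layers=2, normalize_after=True)
model.eval()

ts_model = torch.jit.script(model)   # scripting must still work with normalize_after=True
ts_model.save("srupp_ts.pt")         # consumed by sru/csrc/build/example_app in test.sh
```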