Merge pull request asappresearch#191 from asappresearch/remove_rezero
Update layer norm options
taoleicn authored Jun 17, 2021
2 parents 85936d7 + 6a4ea07 commit 2a344d3
Showing 6 changed files with 33 additions and 36 deletions.
35 changes: 15 additions & 20 deletions sru/modules.py
@@ -350,11 +350,13 @@ def extra_repr(self):
s += ", activation={activation}"
if self.v1:
s += ", v1={v1}"
s += ", rescale={rescale}"
if self.rescale:
s += ", rescale={rescale}"
if not self.has_skip_term:
s += ", has_skip_term={has_skip_term}"
if self.layer_norm:
s += ", layer_norm=True"
s += ", normalize_after={normalize_after}"
s += ",\n transform_module=" + str(self.transform_module)
return s.format(**self.__dict__)

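For reference, a minimal sketch (not part of this diff) of how the updated extra_repr surfaces the new fields; the sizes are illustrative and the exact repr formatting depends on the rest of the cell configuration:

import sru

# With the updated extra_repr, rescale is only shown when it is enabled, and
# normalize_after is printed next to layer_norm in each cell's repr.
rnn = sru.SRUpp(4, 4, 4, num_layers=1, layer_norm=True, normalize_after=True)
print(rnn)  # each layer's repr should now include layer_norm=True, normalize_after=True
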
@@ -958,7 +960,8 @@ def __init__(self,
num_heads: int = 1,
bidirectional: bool = False,
layer_norm: bool = False,
normalization_type: int = 1,
normalize_after: bool = False,
attn_layer_norm: bool = True,
highway_bias: float = -2.0,
attention_every_n_layers: int = 1,
attention_last_n_layers: int = -1,
@@ -988,10 +991,12 @@ def __init__(self,
if True, use bidirectional SRU++ (default=False).
layer_norm: bool, optional
whether to apply layer normalization to each SRU++ layer (default=False).
normalization_type: int, optional
which type of layer normalization to apply. 1: apply normalization after attention.
2: apply normalization before any operators. 3: apply normalization after all operators.
(default=1)
normalize_after: bool, optional
whether to apply post layer norm that normalizes the output of each SRU++ layer
(default=False).
attn_layer_norm: bool, optional
whether to apply layer norm in the attention module or projected linear module if
attention is disabled (default=True).
highway_bias: float, optional
the initial value of the bias used in the highway (sigmoid) gate (default=-1.0).
attention_every_n_layers: int, optional
@@ -1007,10 +1012,6 @@ def __init__(self,
(default=1.0)
"""
if normalization_type > 3 or normalization_type < 1:
raise ValueError("normalization_type={} but expect 1, 2 or 3.".format(
normalization_type
))
if attention_every_n_layers != 1 and attention_last_n_layers != -1:
raise ValueError("Cannot set both attention_every_n_layers and "
"attention_last_n_layers in SRU++ module.")
@@ -1025,7 +1026,6 @@ def __init__(self,
self.rnn_lst = nn.ModuleList()
self.bidirectional = bidirectional
self.use_layer_norm = layer_norm
self.normalization_type = normalization_type
self.num_directions = 2 if bidirectional else 1
self.nn_rnn_compatible_return = nn_rnn_compatible_return
self.input_to_hidden: Optional[nn.Module] = None
@@ -1036,11 +1036,6 @@ def __init__(self,
else:
first_layer_input_size = input_size

# layer norm configuration
module_layer_norm = normalization_type == 1
cell_layer_norm = normalization_type != 1
cell_normalize_after = normalization_type == 3

# attention configuration
if attention_last_n_layers != -1:
use_attention = lambda ind: num_layers - ind <= attention_last_n_layers # noqa
@@ -1061,23 +1056,23 @@ def __init__(self,
dropout=dropout,
attn_dropout=attn_dropout,
num_heads=num_heads,
layer_norm=module_layer_norm,
layer_norm=attn_layer_norm,
)
else:
custom_m = SRUppProjectedLinear(
in_features,
out_features,
proj_features,
dropout=dropout,
layer_norm=module_layer_norm,
layer_norm=attn_layer_norm,
)
layer = SRUppCell(
in_features,
self.hidden_size,
dropout=dropout if i + 1 != num_layers else 0,
bidirectional=bidirectional,
layer_norm=cell_layer_norm,
normalize_after=cell_normalize_after,
layer_norm=layer_norm,
normalize_after=normalize_after,
highway_bias=highway_bias,
rescale=rescale,
transform_module=custom_m,
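For context, the new options can be exercised as in the following sketch. The sizes, the forward call, and the printed shape are illustrative assumptions; the constructor arguments (layer_norm, normalize_after, attn_layer_norm) come from the updated signature above:

import torch
import sru

model = sru.SRUpp(
    64,                     # input size
    64,                     # hidden size
    16,                     # attention projection size
    num_layers=2,
    layer_norm=True,        # apply layer normalization to each SRU++ layer
    normalize_after=True,   # post layer norm: normalize each layer's output
    attn_layer_norm=True,   # layer norm inside the attention (or projected linear) module
)
x = torch.randn(10, 4, 64)  # (length, batch, input_size)
out = model(x)[0]           # first element of the returned tuple is the layer output
print(out.shape)            # expected: torch.Size([10, 4, 64]) for the unidirectional case
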
2 changes: 1 addition & 1 deletion sru/version.py
@@ -1 +1 @@
__version__ = '3.0.0.dev3'
__version__ = '3.0.0.dev6'
16 changes: 11 additions & 5 deletions test/sru/test_sru.py
@@ -218,7 +218,8 @@ def test_srupp_creation(attn_every_n_layers, expected_transform_module):
@pytest.mark.parametrize("compat", [False, True])
@pytest.mark.parametrize("bidirectional", [False, True])
@pytest.mark.parametrize("layer_norm", [False, True])
def test_srupp(cuda, with_grad, compat, bidirectional, layer_norm):
@pytest.mark.parametrize("normalize_after", [False, True])
def test_srupp(cuda, with_grad, compat, bidirectional, layer_norm, normalize_after):
torch.manual_seed(123)
if cuda:
torch.backends.cudnn.deterministic = True
@@ -239,6 +240,7 @@ def run():
layers,
bidirectional=bidirectional,
layer_norm=layer_norm,
normalize_after=normalize_after,
nn_rnn_compatible_return=compat,
)
words_embeddings = torch.rand(
@@ -293,7 +295,8 @@ def cell_to_emb(cell, batch_size):
)
@pytest.mark.parametrize("bidirectional", [False, True])
@pytest.mark.parametrize("layer_norm", [False, True])
def test_srupp_backward_simple(cuda, bidirectional, layer_norm):
@pytest.mark.parametrize("normalize_after", [False, True])
def test_srupp_backward_simple(cuda, bidirectional, layer_norm, normalize_after):
torch.manual_seed(123)
if cuda:
torch.backends.cudnn.deterministic = True
@@ -306,7 +309,8 @@ def test_srupp_backward_simple(cuda, bidirectional, layer_norm):
proj_size = 2
encoder = sru.SRUpp(input_size, hidden_size, proj_size,
bidirectional=bidirectional,
layer_norm=layer_norm)
layer_norm=layer_norm,
normalize_after=normalize_after)
if cuda:
encoder = encoder.cuda()

@@ -324,7 +328,8 @@ def run(x):
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize("bidirectional", [False, True])
@pytest.mark.parametrize("layer_norm", [False, True])
def test_srupp_backward(bidirectional, layer_norm):
@pytest.mark.parametrize("normalize_after", [False, True])
def test_srupp_backward(bidirectional, layer_norm, normalize_after):
eps = 1e-4
torch.manual_seed(123)
torch.backends.cudnn.deterministic = True
@@ -337,7 +342,8 @@ def test_srupp_backward(bidirectional, layer_norm):
proj_size = 2
encoder = sru.SRUpp(input_size, hidden_size, proj_size,
bidirectional=bidirectional,
layer_norm=layer_norm)
layer_norm=layer_norm,
normalize_after=normalize_after)
x = torch.randn(input_length, batch_size, input_size)

# backward in CPU mode
6 changes: 3 additions & 3 deletions test/sru/test_torchscript.py
@@ -66,9 +66,9 @@ def test_sru(cuda, bidirectional, rescale, proj, layer_norm):
)
@pytest.mark.parametrize("bidirectional", [False, True])
@pytest.mark.parametrize("layer_norm", [False, True])
@pytest.mark.parametrize("normalization_type", [1, 2, 3])
@pytest.mark.parametrize("normalize_after", [False, True])
@pytest.mark.parametrize("attn_every_n_layers", [1, 2])
def test_srupp(cuda, bidirectional, layer_norm, normalization_type, attn_every_n_layers):
def test_srupp(cuda, bidirectional, layer_norm, normalize_after, attn_every_n_layers):
eps = 1e-4
torch.manual_seed(1234)
L = 5
@@ -79,7 +79,7 @@ def test_srupp(cuda, bidirectional, layer_norm, normalization_type, attn_every_n_layers):
model = sru.SRUpp(D, D, proj,
bidirectional=bidirectional,
layer_norm=layer_norm,
normalization_type=normalization_type,
normalize_after=normalize_after,
attention_every_n_layers=attn_every_n_layers)
if cuda:
model = model.cuda()
6 changes: 1 addition & 5 deletions test/test.sh
@@ -29,10 +29,6 @@ python test/test_ts_srupp.py > py_srupp_out.txt
sru/csrc/build/example_app srupp_ts.pt > cpp_srupp_out.txt
diff cpp_srupp_out.txt py_srupp_out.txt

python test/test_ts_srupp.py --normalization_type 2 > py_srupp_out.txt
sru/csrc/build/example_app srupp_ts.pt > cpp_srupp_out.txt
diff cpp_srupp_out.txt py_srupp_out.txt

python test/test_ts_srupp.py --normalization_type 3 > py_srupp_out.txt
python test/test_ts_srupp.py --normalize-after > py_srupp_out.txt
sru/csrc/build/example_app srupp_ts.pt > cpp_srupp_out.txt
diff cpp_srupp_out.txt py_srupp_out.txt
4 changes: 2 additions & 2 deletions test/test_ts_srupp.py
@@ -5,7 +5,7 @@

def run(args):
D = 4
model = sru.SRUpp(D, D, D, num_layers=2, normalization_type=args.normalization_type)
model = sru.SRUpp(D, D, D, num_layers=2, normalize_after=args.normalize_after)
model.eval()

ts_model = torch.jit.script(model)
@@ -21,6 +21,6 @@ def run(args):

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--normalization_type', type=int, default=1)
parser.add_argument('--normalize-after', action='store_true')
args = parser.parse_args()
run(args)
