Add head_mask/decoder_head_mask for TF BART models #9639
Changes from all commits (d3414e6 through 0ba22f0). Large diffs for several files are not rendered by default; only the test-suite hunks below are shown.
@@ -440,6 +440,11 @@ def test_pt_tf_model_equivalence(self):

    def test_train_pipeline_custom_model(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        # head_mask and decoder_head_mask have different shapes than the other input args
        if "head_mask" in inputs_dict:
            del inputs_dict["head_mask"]
        if "decoder_head_mask" in inputs_dict:
            del inputs_dict["decoder_head_mask"]
        tf_main_layer_classes = set(
            module_member
            for model_class in self.all_model_classes
@@ -620,6 +625,75 @@ def check_encoder_attentions_output(outputs):
            self.assertEqual(model.config.output_hidden_states, True)
            check_encoder_attentions_output(outputs)

    def test_headmasking(self):

[Review comment on this line] As this test is only for a few specific models, I would rather see it inside their respective test files and remove all the occurrences of …

[Reply] I'd rather have that in a common file with …

        if not self.test_head_masking:
            return

        random.Random().seed(42)
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        random.Random().seed()

        inputs_dict["output_attentions"] = True
        config.output_hidden_states = True
        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)

            # Prepare head_mask
            def prepare_layer_head_mask(i, attention_heads, num_hidden_layers):
                if i == 0:
                    return tf.concat(
                        (tf.zeros(1, dtype=tf.float32), tf.ones(attention_heads - 1, dtype=tf.float32)), 0
                    )
                elif i == num_hidden_layers - 1:
                    return tf.concat(
                        (tf.zeros(attention_heads - 1, dtype=tf.float32), tf.ones(1, dtype=tf.float32)), 0
                    )
                else:
                    return tf.ones(attention_heads, dtype=tf.float32)

            head_mask = tf.stack(
                [
                    prepare_layer_head_mask(i, config.num_attention_heads, config.num_hidden_layers)
                    for i in range(config.num_hidden_layers)
                ],
                0,
            )

            inputs = self._prepare_for_class(inputs_dict, model_class).copy()
            inputs["head_mask"] = head_mask
            if model.config.is_encoder_decoder:
                signature = inspect.signature(model.call)
                arg_names = [*signature.parameters.keys()]
                if "decoder_head_mask" in arg_names:  # necessary differentiation because of the T5 model
                    inputs["decoder_head_mask"] = head_mask

[Review comment on lines +667 to +669] Nice!

            outputs = model(**inputs, return_dict=True)

            def check_attentions_validity(attentions):
                # Remove Nan
                for t in attentions:
                    self.assertLess(
                        (tf.math.reduce_sum(tf.cast(tf.math.is_nan(t), tf.float32))).numpy(), (tf.size(t) / 4).numpy()
                    )  # Check we don't have more than 25% nans (arbitrary)

                attentions = [
                    tf.where(tf.math.is_nan(t), 0.0, t) for t in attentions
                ]  # remove them (the test is less complete)

                self.assertAlmostEqual(tf.math.reduce_sum(attentions[0][..., 0, :, :]).numpy(), 0.0)
                self.assertNotEqual(tf.math.reduce_sum(attentions[0][..., -1, :, :]).numpy(), 0.0)
                if len(attentions) > 2:  # encoder-decoder models have only 2 layers in each module
                    self.assertNotEqual(tf.math.reduce_sum(attentions[1][..., 0, :, :]).numpy(), 0.0)
                self.assertAlmostEqual(tf.math.reduce_sum(attentions[-1][..., -2, :, :]).numpy(), 0.0)
                self.assertNotEqual(tf.math.reduce_sum(attentions[-1][..., -1, :, :]).numpy(), 0.0)

            if model.config.is_encoder_decoder:
                check_attentions_validity(outputs.encoder_attentions)
                check_attentions_validity(outputs.decoder_attentions)
            else:
                check_attentions_validity(outputs.attentions)

    def test_hidden_states_output(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
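
For reference, a small standalone sketch (not part of the diff) of the mask that prepare_layer_head_mask builds, using a toy 3-layer, 4-head config: head 0 is masked in the first layer, the middle layers keep every head, and the last layer keeps only its last head, which is exactly what check_attentions_validity asserts on.

```python
import tensorflow as tf

# Same helper as in the diff above, reproduced standalone so the resulting
# mask can be inspected; the 3-layer / 4-head config is purely illustrative.
def prepare_layer_head_mask(i, attention_heads, num_hidden_layers):
    if i == 0:
        # first layer: mask only head 0
        return tf.concat((tf.zeros(1), tf.ones(attention_heads - 1)), 0)
    elif i == num_hidden_layers - 1:
        # last layer: keep only the last head
        return tf.concat((tf.zeros(attention_heads - 1), tf.ones(1)), 0)
    return tf.ones(attention_heads)

head_mask = tf.stack([prepare_layer_head_mask(i, 4, 3) for i in range(3)], 0)
print(head_mask.numpy())
# [[0. 1. 1. 1.]
#  [1. 1. 1. 1.]
#  [0. 0. 0. 1.]]
```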

Review discussion (on removing head_mask / decoder_head_mask in test_train_pipeline_custom_model):

@jplu: Why remove them? Aren't we supposed to be able to train a Seq2Seq model with these arguments?

Author: @jplu - Currently, I'm not fully sure how to implement this test with head_mask and decoder_head_mask, as they have different shapes than input_ids, attention_mask, etc. While the latter have shapes of (batch_size, seq_len), the head_mask has shape (num_layers, num_attention_heads). This results in the following error [error output not rendered]. Do you have any idea of how to overcome this problem?

Furthermore, at this moment, this test does not consider head_mask for any BERT-like models, because head_mask is not tested for these models at all. I'm definitely willing to implement such testing for models other than the BART-like ones as well, once we're sure about the proper implementation.

@jplu: Ok, indeed in a tf dataset all the inputs must share the same shape. To overcome this, you should create your own test_train_pipeline_custom_model for each Seq2Seq model and replace the tf dataset with the dictionary. In any case, you should not modify a common test just to make the test pass for a single model.
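
A minimal sketch of the constraint @jplu describes, assuming the training pipeline wraps the input dict in tf.data.Dataset.from_tensor_slices; the toy shapes below are illustrative and not taken from the PR:

```python
import tensorflow as tf

# Inputs like input_ids / attention_mask share a leading batch dimension ...
input_ids = tf.ones((4, 7), dtype=tf.int32)        # (batch_size=4, seq_len=7)
attention_mask = tf.ones((4, 7), dtype=tf.int32)   # (batch_size=4, seq_len=7)
# ... but head_mask is (num_layers, num_heads) and has no batch dimension.
head_mask = tf.ones((2, 3), dtype=tf.float32)      # (num_layers=2, num_heads=3)

# from_tensor_slices slices every tensor along axis 0, so all leading
# dimensions must match; including head_mask therefore fails.
try:
    tf.data.Dataset.from_tensor_slices(
        {"input_ids": input_ids, "attention_mask": attention_mask, "head_mask": head_mask}
    )
except Exception as err:  # raised because 4 != 2 on the first axis
    print(type(err).__name__, err)

# Dropping the head-mask keys, as the modified test does, restores a
# consistent leading dimension and the dataset can be built.
dataset = tf.data.Dataset.from_tensor_slices(
    {"input_ids": input_ids, "attention_mask": attention_mask}
)
```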

Another reviewer: To be honest, I don't really agree here, @jplu. IMO it's fine to remove head_mask from the inputs_dict; we remove head_mask for other tests as well, e.g. in transformers/tests/test_modeling_common.py at line 521 and line 1076 (as of fac7cfb). Also, test_train_pipeline_custom_model is not meant to test head masking, so removing these inputs means the test covers exactly the same functionality as before. All "normal" head-masking tests were enabled in this PR. I'm fine with leaving it as is for now.

@jplu: Hmm, ok, so that means you are not supposed to train BART with such arguments, right?

Author: Contemporary research does not consider training models with head_mask; the masking is applied at inference time, once the model has been trained with all attention heads.
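
To make that concrete, here is a hedged sketch of inference-time head masking with the new arguments; the checkpoint name is only an example, and the call signature is assumed to follow this PR (an encoder head_mask and a decoder_head_mask, each of shape (num_layers, num_heads)):

```python
import tensorflow as tf
from transformers import BartTokenizer, TFBartForConditionalGeneration

# Example checkpoint; any TF BART checkpoint would do.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-base")

inputs = tokenizer("Head masking is applied only at inference time.", return_tensors="tf")

# 1.0 keeps a head, 0.0 masks it; shape is (num_layers, num_heads).
head_mask = tf.ones((model.config.encoder_layers, model.config.encoder_attention_heads))
decoder_head_mask = tf.ones((model.config.decoder_layers, model.config.decoder_attention_heads))
# e.g. disable the first attention head of the first encoder layer
head_mask = tf.tensor_scatter_nd_update(head_mask, [[0, 0]], [0.0])

# BART creates decoder_input_ids from input_ids when they are not provided.
outputs = model(
    **inputs,
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
    output_attentions=True,
    return_dict=True,
)
print(outputs.encoder_attentions[0].shape)  # (batch, num_heads, seq_len, seq_len)
```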

@jplu: Ok, perfect then! I'm fine with leaving it like this :)