Add head_mask/decoder_head_mask for TF BART models #9639
Conversation
* Add `head_mask` and `decoder_head_mask` input arguments for TF BART-based models as a TF counterpart to the PR huggingface#9569
* Add `test_headmasking` functionality to `tests/test_modeling_tf_common.py`
* TODO: Add a test to verify that we can get a gradient back for importance score computation
@stancld, thanks so much for tackling this! I think it would be a great addition if we could add a …
Remove redundant #TODO note from tests/test_modeling_tf_common.py
Hey @patrickvonplaten, I hope this PR is ready for review. There's newly implemented `test_headmasking` functionality in `tests/test_modeling_tf_common.py`. It seems all checks have passed after rebasing this PR.
tests/test_modeling_tf_common.py (outdated)

```python
@@ -1068,10 +1141,13 @@ def _check_match_tokens(self, generated_ids, bad_words_ids):
        return False


global_rng = random.Random()
```
not sure if we want to have a global random seed, was this an intended change or just used for debugging?
Here, I followed the convention from the corresponding PyTorch testing file `test_modeling_common.py`, where `global_rng` is also defined for a similar purpose. However, I think it's okay to omit this change and leave it as it was :)
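For context, the PyTorch convention being referenced looks roughly like this (a simplified sketch of `tests/test_modeling_common.py`; the exact helper there differs slightly):

```python
import random

import torch

global_rng = random.Random()


def ids_tensor(shape, vocab_size, rng=None):
    # Draw from the module-level RNG unless a seeded one is supplied.
    rng = rng or global_rng
    total_dims = 1
    for dim in shape:
        total_dims *= dim
    values = [rng.randint(0, vocab_size - 1) for _ in range(total_dims)]
    return torch.tensor(values, dtype=torch.long).view(shape).contiguous()
```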
I would be more in favor of keeping it as it was.
@jplu I removed `global_rng` and left it as it was before the changes. Hopefully, this PR is now ready for a final review.
Very clean! Thanks a lot for the contribution.

I'm only wondering about the global random seed; what was the reason behind this change?

Also, @jplu, it would be great if you could take a quick look at whether this is all serving-compatible (I don't see a reason why it wouldn't be, though).
Thanks a lot for this really nice PR!! Before approving on my side, can you comment out these lines:

```python
@slow
def test_saved_model_with_hidden_states_output(self):
    pass

@slow
def test_saved_model_with_attentions_output(self):
    pass

def test_saved_model_creation(self):
    pass

def test_saved_model_creation_extended(self):
    pass
```

and run them with the environment variable `RUN_SLOW=1` in order to be sure that everything is ok. You can do it at least just for BART.

If all these tests pass, I will approve; otherwise, we will have to fix them.
```python
@@ -618,6 +622,75 @@ def check_encoder_attentions_output(outputs):
        self.assertEqual(model.config.output_hidden_states, True)
        check_encoder_attentions_output(outputs)

    def test_headmasking(self):
```
As this test is only for a few specific models, I would rather see it inside their respective test files and remove all the occurrences of `test_head_masking = False`.
I'd rather have that in a common file with `test_head_masking = False` in the other files; that's what we do for tests that are shared, and it gives the incentive of switching these to `True`.
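As an illustration of the pattern @sgugger describes, a minimal sketch (the test class name is hypothetical; `TFModelTesterMixin` is the shared tester in `tests/test_modeling_tf_common.py`):

```python
import unittest

from .test_modeling_tf_common import TFModelTesterMixin


class TFSomeModelModelTest(TFModelTesterMixin, unittest.TestCase):
    # Opt out of the shared test_headmasking; the common test checks this
    # flag and returns early, and flipping it to True later re-enables it.
    test_head_masking = False
```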
LGTM! Thanks for working on this @stancld
```python
arg_names = [*signature.parameters.keys()]
if "decoder_head_mask" in arg_names:  # necessary differentiation because of T5 model
    inputs["decoder_head_mask"] = head_mask
```
Nice!
Thanks for adding this. There might be a few assert/shapes to rework.
```python
@@ -230,6 +231,15 @@ def call(

        attn_weights = tf.nn.softmax(attn_weights, axis=-1)

        if layer_head_mask is not None:
            assert layer_head_mask.shape == (
```
I think we should use `shape_list` here. The assert might also need to be wrapped in some TF operation. @jplu will know better than me.
Oh, good catch!! The assert must also be replaced by a `tf.debugging.assert` equivalent.
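For reference, a hedged sketch of what the graph-compatible version could look like, assuming `shape_list` from `transformers.modeling_tf_utils` and the surrounding attention variables (`self.num_heads`, `bsz`, `tgt_len`, `src_len`) as in TF BART:

```python
import tensorflow as tf
from transformers.modeling_tf_utils import shape_list

# Inside the attention layer's call(), after the softmax:
if layer_head_mask is not None:
    # Graph-friendly replacement for the Python assert.
    tf.debugging.assert_equal(
        shape_list(layer_head_mask),
        [self.num_heads],
        message=f"Head mask for a single layer should be of size {self.num_heads}",
    )
    # Broadcast the per-head mask over the batch and sequence dimensions.
    attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
        attn_weights, (bsz, self.num_heads, tgt_len, src_len)
    )
    attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
```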
```python
assert inputs["head_mask"].shape[0] == (
    len(self.layers)
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {inputs['head_mask'].shape[0]}."
```
Same here (I won't comment on the other modeling files, but obviously it applies to all of them).
Indeed, same here.
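Again for reference, a hedged sketch of a graph-compatible version of this module-level check (same `shape_list` assumption as above):

```python
if inputs["head_mask"] is not None:
    tf.debugging.assert_equal(
        shape_list(inputs["head_mask"])[0],
        len(self.layers),
        message=(
            f"The head_mask should be specified for {len(self.layers)} layers, "
            f"but it is for {shape_list(inputs['head_mask'])[0]}."
        ),
    )
```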
```python
assert inputs["head_mask"].shape[0] == (
    len(self.layers)
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {inputs['head_mask'].shape[0]}."
```
And same here.
And here as well!
I've just done further tests on your PR; the changes are not graph-compliant, and the following slow tests are failing:

One of the reasons is what @sgugger raised.
* Add back `head_mask` and `decoder_head_mask` to BART-based ...Model after the last commit
* Remove `head_mask` and `decoder_head_mask` from `input_dict` in TF `test_train_pipeline_custom_model`, as these two have different shapes than other input args (necessary for passing this test)
Hi @jplu, could you please review the changes I've made to the code and say whether the assertions are done more appropriately now? :)
I confirm that the assertions are done more appropriately now! Those four tests are among the most important ones for the TF code base (they are run in slow mode because, unfortunately, they take some time to execute). If you need some help to make them pass, I will be happy to help.
```python
# head_mask and decoder_head_mask have different shapes than other input args
if "head_mask" in inputs_dict:
    del inputs_dict["head_mask"]
if "decoder_head_mask" in inputs_dict:
    del inputs_dict["decoder_head_mask"]
```
Why remove them? Aren't we supposed to be able to train a Seq2Seq model with these arguments?
@jplu - Currently, I'm not fully sure how to implement this test with `head_mask` and `decoder_head_mask`, as they have different shapes than `input_ids`, `attention_mask`, etc. While the latter ones have shapes of `(batch_size, seq_len)`, the `head_mask` has `(num_layers, num_attention_heads)`. This results in the following error:
```python
>       X = tf.data.Dataset.from_tensor_slices(
            (inputs_dict, np.ones((self.model_tester.batch_size, self.model_tester.seq_length, num_labels, 1)))
        ).batch(1)
----------------------------------
self = Dimension(13), other = Dimension(5)

    def assert_is_compatible_with(self, other):
        """Raises an exception if `other` is not compatible with this Dimension.

        Args:
          other: Another Dimension.

        Raises:
          ValueError: If `self` and `other` are not compatible (see
            is_compatible_with).
        """
        if not self.is_compatible_with(other):
>           raise ValueError("Dimensions %s and %s are not compatible" %
                             (self, other))
E           ValueError: Dimensions 13 and 5 are not compatible

../../../../../../miniconda3/envs/bart/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:281: ValueError
```
Do you have any idea of how to overcome this problem?
Furthermore, at this moment, this test does not consider `head_mask` for any BERT-like models, because `head_mask` is not tested for these models at all. I'm definitely willing to implement such testing for models other than BART-like ones as well, once we're sure about the proper implementation.
Ok, indeed in a tf dataset all the inputs must share the same shape. To overcome this, you should create your own `test_train_pipeline_custom_model` for each Seq2Seq model and replace the tf dataset with the dictionary.

Anyway, you should not modify a common test just to make the test pass for a single model.
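A minimal demonstration of the constraint being described (shapes illustrative):

```python
import numpy as np
import tensorflow as tf

# from_tensor_slices slices every input along its first dimension,
# so all inputs must agree there.
inputs = {
    "input_ids": np.ones((13, 5), dtype=np.int32),       # (batch_size, seq_len)
    "attention_mask": np.ones((13, 5), dtype=np.int32),  # (batch_size, seq_len)
}
tf.data.Dataset.from_tensor_slices(inputs)  # fine

# head_mask has shape (num_layers, num_attention_heads) instead:
inputs["head_mask"] = np.ones((2, 4), dtype=np.float32)
# tf.data.Dataset.from_tensor_slices(inputs)
# -> ValueError: Dimensions 13 and 2 are not compatible
```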
To be honest, I don't really agree here, @jplu. IMO it's fine to remove `head_mask` from the `inputs_dict`. We remove the `head_mask` for other tests as well, such as `transformers/tests/test_modeling_common.py` line 521 in fac7cfb (`if "head_mask" in inputs_dict:`) and line 1076 in fac7cfb (`blacklist_non_batched_params = ["head_mask", "decoder_head_mask"]`).

Also, the `test_train_pipeline_custom_model` is not meant to test head masking, and removing it just means that the test covers the exact same functionalities as before. All "normal" head-masking tests were enabled in this PR.

I'm fine with leaving it as is for now.
Hmm, ok, so that means you are not supposed to train BART with such arguments, right?
Contemporary research does not consider training models with `head_mask`. The head masking is applied during inference, once the model has been trained with all attention heads.
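A sketch of the inference-time usage this PR enables (checkpoint name illustrative; 1.0 keeps a head, 0.0 masks it):

```python
import numpy as np
import tensorflow as tf
from transformers import BartTokenizer, TFBartForConditionalGeneration

model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-base")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
inputs = tokenizer("Head masking is applied at inference time.", return_tensors="tf")

# head_mask has shape (num_layers, num_attention_heads).
mask = np.ones(
    (model.config.encoder_layers, model.config.encoder_attention_heads),
    dtype=np.float32,
)
mask[:, 0] = 0.0  # disable the first head in every encoder layer
outputs = model(inputs["input_ids"], head_mask=tf.constant(mask))
```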
Ok, perfect then! I'm fine with leaving it like this :)

Do these tests pass now? If yes, I will approve the PR :)
@jplu I ran these four aforementioned tests for BART, and they all passed.
Ok then, it's perfect!! Great work @stancld!
Merging, thanks a lot for your efforts @stancld!!
This PR adds `head_mask` and `decoder_head_mask` input arguments for TF BART-based models. The full list of models is as follows:

This PR can be deemed a TF counterpart to the PR #9569.

Further information:

* I've added `test_headmasking` functionality to `tests/test_modeling_tf_common.py`.
* TODO: Add a test (as a part of `test_headmasking`) to verify that we can get a gradient back for importance score computation. I am not so familiar with TensorFlow; therefore, I am not fully sure about a TF equivalent to

Reviewer: @patrickvonplaten
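Regarding the gradient TODO, a hedged sketch of one possible check in TF using `tf.GradientTape` (checkpoint name illustrative; not necessarily the approach the PR had in mind):

```python
import tensorflow as tf
from transformers import BartTokenizer, TFBartModel

model = TFBartModel.from_pretrained("facebook/bart-base")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
inputs = tokenizer("gradient check for head importance", return_tensors="tf")

head_mask = tf.ones((model.config.encoder_layers, model.config.encoder_attention_heads))
with tf.GradientTape() as tape:
    tape.watch(head_mask)  # head_mask is a constant tensor, so watch it explicitly
    outputs = model(inputs["input_ids"], head_mask=head_mask)
    loss = tf.reduce_sum(outputs.last_hidden_state)

# A non-None gradient confirms head importance scores can be computed.
grads = tape.gradient(loss, head_mask)
assert grads is not None
```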