
Add head_mask/decoder_head_mask for TF BART models #9639

Merged · 9 commits · Jan 26, 2021
119 changes: 114 additions & 5 deletions src/transformers/models/bart/modeling_tf_bart.py

Large diffs are not rendered by default.

120 changes: 115 additions & 5 deletions src/transformers/models/blenderbot/modeling_tf_blenderbot.py

Large diffs are not rendered by default.

src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py

Large diffs are not rendered by default.

120 changes: 115 additions & 5 deletions src/transformers/models/marian/modeling_tf_marian.py

Large diffs are not rendered by default.

120 changes: 115 additions & 5 deletions src/transformers/models/mbart/modeling_tf_mbart.py

Large diffs are not rendered by default.

120 changes: 115 additions & 5 deletions src/transformers/models/pegasus/modeling_tf_pegasus.py

Large diffs are not rendered by default.
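
The modeling diffs above add optional head_mask and decoder_head_mask arguments to the call signatures of the TF seq2seq models. A minimal usage sketch (not taken from the PR itself; the tiny config values are invented so the snippet runs quickly, but the argument names follow this PR):

import tensorflow as tf
from transformers import BartConfig, TFBartModel

config = BartConfig(
    vocab_size=100,
    d_model=16,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=4,
    decoder_attention_heads=4,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
)
model = TFBartModel(config)

input_ids = tf.constant([[0, 5, 6, 7, 2]])
decoder_input_ids = tf.constant([[2, 0, 5, 6, 7]])

# head_mask has shape (encoder_layers, encoder_attention_heads) and
# decoder_head_mask has shape (decoder_layers, decoder_attention_heads);
# 1.0 keeps a head, 0.0 masks it out.
head_mask = tf.constant([[0.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]])
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))

outputs = model(
    input_ids,
    decoder_input_ids=decoder_input_ids,
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
    output_attentions=True,
)
print(outputs.encoder_attentions[0].shape)  # (1, 4, 5, 5)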

1 change: 1 addition & 0 deletions tests/test_modeling_tf_albert.py
@@ -240,6 +240,7 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
test_head_masking = False

def setUp(self):
self.model_tester = TFAlbertModelTester(self)

12 changes: 11 additions & 1 deletion tests/test_modeling_tf_bart.py
@@ -108,10 +108,11 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict):

input_ids = input_ids[:1, :]
attention_mask = inputs_dict["attention_mask"][:1, :]
head_mask = inputs_dict["head_mask"]
self.batch_size = 1

# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
@@ -144,6 +145,8 @@ def prepare_bart_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -155,11 +158,17 @@
],
axis=-1,
)
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": head_mask,
}


@@ -169,6 +178,7 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = True

def setUp(self):
self.model_tester = TFBartModelTester(self)

1 change: 1 addition & 0 deletions tests/test_modeling_tf_bert.py
@@ -273,6 +273,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
test_head_masking = False

# special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):

12 changes: 11 additions & 1 deletion tests/test_modeling_tf_blenderbot.py
@@ -107,10 +107,11 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict):

input_ids = input_ids[:1, :]
attention_mask = inputs_dict["attention_mask"][:1, :]
head_mask = inputs_dict["head_mask"]
self.batch_size = 1

# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
@@ -143,6 +144,8 @@ def prepare_blenderbot_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -154,11 +157,17 @@
],
axis=-1,
)
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
}


@@ -168,6 +177,7 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = True

def setUp(self):
self.model_tester = TFBlenderbotModelTester(self)

12 changes: 11 additions & 1 deletion tests/test_modeling_tf_blenderbot_small.py
@@ -107,10 +107,11 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict):

input_ids = input_ids[:1, :]
attention_mask = inputs_dict["attention_mask"][:1, :]
head_mask = inputs_dict["head_mask"]
self.batch_size = 1

# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
@@ -143,6 +144,8 @@ def prepare_blenderbot_small_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -154,11 +157,17 @@
],
axis=-1,
)
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
}


@@ -170,6 +179,7 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFBlenderbotSmallForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = True

def setUp(self):
self.model_tester = TFBlenderbotSmallModelTester(self)

74 changes: 74 additions & 0 deletions tests/test_modeling_tf_common.py
@@ -440,6 +440,11 @@ def test_pt_tf_model_equivalence(self):

def test_train_pipeline_custom_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# head_mask and decoder_head_mask have different shapes than the other input args
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
if "decoder_head_mask" in inputs_dict:
del inputs_dict["decoder_head_mask"]
Comment on lines +443 to +447
Contributor @jplu commented on Jan 22, 2021:

Why remove them? Are we not supposed to be able to train a Seq2Seq model with these arguments?

Contributor (PR author) replied:

@jplu - Currently, I'm not fully sure how to implement this test with head_mask and decoder_head_mask, as they have different shapes from input_ids, attention_mask, etc. While the latter have shape (batch_size, seq_len), head_mask has shape (num_layers, num_attention_heads). This results in the following error:

> X = tf.data.Dataset.from_tensor_slices(
    (inputs_dict, np.ones((self.model_tester.batch_size, self.model_tester.seq_length, num_labels, 1)))
).batch(1)
----------------------------------
self = Dimension(13), other = Dimension(5)

    def assert_is_compatible_with(self, other):
      """Raises an exception if `other` is not compatible with this Dimension.
    
      Args:
        other: Another Dimension.
    
      Raises:
        ValueError: If `self` and `other` are not compatible (see
          is_compatible_with).
      """
      if not self.is_compatible_with(other):
>       raise ValueError("Dimensions %s and %s are not compatible" %
                         (self, other))
E       ValueError: Dimensions 13 and 5 are not compatible

../../../../../../miniconda3/envs/bart/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:281: ValueError

Do you have any idea of how to overcome this problem?

Furthermore, at this moment, this test does not consider head_mask for any BERT-like models because head_mask is not tested for these models at all. I'm definitely willing to implement such testing for models other than the BART-like ones as well, once we're sure about the proper implementation.
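
As an aside, the shape clash described above can be reproduced with a standalone toy snippet; the dimensions below are invented and only mirror the pytest output quoted earlier.

import numpy as np
import tensorflow as tf

batch_size, seq_length, num_layers, num_heads, num_labels = 13, 7, 5, 4, 3
inputs_dict = {
    "input_ids": np.ones((batch_size, seq_length), dtype=np.int32),
    "attention_mask": np.ones((batch_size, seq_length), dtype=np.int32),
    "head_mask": np.ones((num_layers, num_heads), dtype=np.float32),  # leading dim 5, not 13
}
labels = np.ones((batch_size, seq_length, num_labels, 1))

# from_tensor_slices slices every tensor along its first axis, so all leading
# dimensions must match; head_mask's leading dimension is num_layers, hence the error.
try:
    tf.data.Dataset.from_tensor_slices((inputs_dict, labels)).batch(1)
except ValueError as err:
    print(err)  # e.g. "Dimensions 13 and 5 are not compatible"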

Contributor replied:

Ok, indeed, in a tf dataset all the inputs must share the same shape. To overcome this, you should create your own test_train_pipeline_custom_model for each Seq2Seq model and replace the tf dataset with the dictionary.

Anyway, you should not modify a common test just to make the test pass for a single model.

Contributor @patrickvonplaten commented on Jan 25, 2021:

To be honest, I don't really agree here @jplu. IMO it's fine to remove head_mask from the inputs_dict. We remove the head_mask for other tests as well, for example:

if "head_mask" in inputs_dict:
blacklist_non_batched_params = ["head_mask", "decoder_head_mask"]

Also, test_train_pipeline_custom_model is not meant to test head masking, and removing these inputs just means that the test covers the exact same functionality as before. All "normal" head-masking tests were enabled in this PR.
I'm fine with leaving it as it is for now.

Contributor replied:

Hmm, ok, so that means you're not supposed to train BART with such arguments, right?

Contributor (PR author) replied:

Contemporary research does not consider training models with head_mask. Head masking is applied during inference, once the model has been trained with all attention heads.

Contributor replied:

Ok, perfect then! I'm fine with leaving it like this :)

tf_main_layer_classes = set(
module_member
for model_class in self.all_model_classes
@@ -620,6 +625,75 @@ def check_encoder_attentions_output(outputs):
self.assertEqual(model.config.output_hidden_states, True)
check_encoder_attentions_output(outputs)

def test_headmasking(self):
Contributor commented:

As this test is only for a few specific models, I would rather see it inside their respective test files and remove all the occurrences of test_head_masking = False.

Member replied:

I'd rather have that in a common file with test_head_masking=False in the other files; that's what we do for tests that are shared, and it gives the incentive of switching these to True.

if not self.test_head_masking:
return

random.Random().seed(42)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
random.Random().seed()

inputs_dict["output_attentions"] = True
config.output_hidden_states = True
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)

# Prepare head_mask
def prepare_layer_head_mask(i, attention_heads, num_hidden_layers):
if i == 0:
return tf.concat(
(tf.zeros(1, dtype=tf.float32), tf.ones(attention_heads - 1, dtype=tf.float32)), 0
)
elif i == num_hidden_layers - 1:
return tf.concat(
(tf.zeros(attention_heads - 1, dtype=tf.float32), tf.ones(1, dtype=tf.float32)), 0
)
else:
return tf.ones(attention_heads, dtype=tf.float32)

head_mask = tf.stack(
[
prepare_layer_head_mask(i, config.num_attention_heads, config.num_hidden_layers)
for i in range(config.num_hidden_layers)
],
0,
)

inputs = self._prepare_for_class(inputs_dict, model_class).copy()
inputs["head_mask"] = head_mask
if model.config.is_encoder_decoder:
signature = inspect.signature(model.call)
arg_names = [*signature.parameters.keys()]
if "decoder_head_mask" in arg_names: # necessary diferentiation because of T5 model
inputs["decoder_head_mask"] = head_mask
Comment on lines +667 to +669
Member commented:

Nice!


outputs = model(**inputs, return_dict=True)

def check_attentions_validity(attentions):
# Remove Nan
for t in attentions:
self.assertLess(
(tf.math.reduce_sum(tf.cast(tf.math.is_nan(t), tf.float32))).numpy(), (tf.size(t) / 4).numpy()
) # Check we don't have more than 25% nans (arbitrary)

attentions = [
tf.where(tf.math.is_nan(t), 0.0, t) for t in attentions
] # remove them (the test is less complete)

self.assertAlmostEqual(tf.math.reduce_sum(attentions[0][..., 0, :, :]).numpy(), 0.0)
self.assertNotEqual(tf.math.reduce_sum(attentions[0][..., -1, :, :]).numpy(), 0.0)
if len(attentions) > 2: # encoder-decoder models have only 2 layers in each module
self.assertNotEqual(tf.math.reduce_sum(attentions[1][..., 0, :, :]).numpy(), 0.0)
self.assertAlmostEqual(tf.math.reduce_sum(attentions[-1][..., -2, :, :]).numpy(), 0.0)
self.assertNotEqual(tf.math.reduce_sum(attentions[-1][..., -1, :, :]).numpy(), 0.0)

if model.config.is_encoder_decoder:
check_attentions_validity(outputs.encoder_attentions)
check_attentions_validity(outputs.decoder_attentions)
else:
check_attentions_validity(outputs.attentions)

def test_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

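
As an aside on the new test_headmasking common test above: the assertions in check_attentions_validity come down to the fact that multiplying attention probabilities by a per-layer head mask zeroes out the masked head's weights while leaving the other heads untouched. A standalone illustration (not part of the PR; shapes invented):

import tensorflow as tf

batch_size, num_heads, seq_len = 2, 4, 6
# Random attention probabilities for one layer: (batch, heads, query, key)
attn = tf.nn.softmax(tf.random.normal((batch_size, num_heads, seq_len, seq_len)), axis=-1)

# Same pattern as prepare_layer_head_mask for the first layer: head 0 masked, the rest kept
layer_head_mask = tf.constant([0.0, 1.0, 1.0, 1.0])
masked_attn = attn * layer_head_mask[None, :, None, None]

# Mirrors the first two assertions in check_attentions_validity
print(float(tf.reduce_sum(masked_attn[..., 0, :, :])))   # 0.0: the masked head contributes nothing
print(float(tf.reduce_sum(masked_attn[..., -1, :, :])))  # > 0: an unmasked head still attends
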
1 change: 1 addition & 0 deletions tests/test_modeling_tf_ctrl.py
@@ -173,6 +173,7 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):

all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel, TFCTRLForSequenceClassification) if is_tf_available() else ()
all_generative_model_classes = (TFCTRLLMHeadModel,) if is_tf_available() else ()
test_head_masking = False

def setUp(self):
self.model_tester = TFCTRLModelTester(self)

1 change: 1 addition & 0 deletions tests/test_modeling_tf_distilbert.py
@@ -183,6 +183,7 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else None
)
test_head_masking = False

def setUp(self):
self.model_tester = TFDistilBertModelTester(self)

1 change: 1 addition & 0 deletions tests/test_modeling_tf_electra.py
@@ -205,6 +205,7 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
test_head_masking = False

def setUp(self):
self.model_tester = TFElectraModelTester(self)

1 change: 1 addition & 0 deletions tests/test_modeling_tf_flaubert.py
@@ -291,6 +291,7 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (
(TFFlaubertWithLMHeadModel,) if is_tf_available() else ()
) # TODO (PVP): Check other models whether language generation is also applicable
test_head_masking = False

def setUp(self):
self.model_tester = TFFlaubertModelTester(self)

2 changes: 2 additions & 0 deletions tests/test_modeling_tf_funnel.py
@@ -338,6 +338,7 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
test_head_masking = False

def setUp(self):
self.model_tester = TFFunnelModelTester(self)
@@ -376,6 +377,7 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (
(TFFunnelBaseModel, TFFunnelForMultipleChoice, TFFunnelForSequenceClassification) if is_tf_available() else ()
)
test_head_masking = False

def setUp(self):
self.model_tester = TFFunnelModelTester(self, base=True)

1 change: 1 addition & 0 deletions tests/test_modeling_tf_gpt2.py
@@ -332,6 +332,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
else ()
)
all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else ()
test_head_masking = False

def setUp(self):
self.model_tester = TFGPT2ModelTester(self)

1 change: 1 addition & 0 deletions tests/test_modeling_tf_led.py
@@ -187,6 +187,7 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = False

def setUp(self):
self.model_tester = TFLEDModelTester(self)

1 change: 1 addition & 0 deletions tests/test_modeling_tf_longformer.py
@@ -297,6 +297,7 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
test_head_masking = False

def setUp(self):
self.model_tester = TFLongformerModelTester(self)

1 change: 1 addition & 0 deletions tests/test_modeling_tf_lxmert.py
@@ -361,6 +361,7 @@ def create_and_check_lxmert_for_pretraining(
class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):

all_model_classes = (TFLxmertModel, TFLxmertForPreTraining) if is_tf_available() else ()
test_head_masking = False

def setUp(self):
self.model_tester = TFLxmertModelTester(self)

12 changes: 11 additions & 1 deletion tests/test_modeling_tf_marian.py
@@ -109,10 +109,11 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict):

input_ids = input_ids[:1, :]
attention_mask = inputs_dict["attention_mask"][:1, :]
head_mask = inputs_dict["head_mask"]
self.batch_size = 1

# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
@@ -145,6 +146,8 @@ def prepare_marian_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -156,11 +159,17 @@
],
axis=-1,
)
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
}


@@ -170,6 +179,7 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = True

def setUp(self):
self.model_tester = TFMarianModelTester(self)