Add head_mask/decoder_head_mask for TF BART models (#9639)
* Add head_mask/decoder_head_mask for TF BART models

* Add head_mask and decoder_head_mask input arguments for TF BART-based
models as the TF counterpart to PR #9569

* Add test_headmasking functionality to tests/test_modeling_tf_common.py

* TODO: Add a test to verify that we can get a gradient back for
importance score computation

* Remove redundant #TODO note

Remove redundant #TODO note from tests/test_modeling_tf_common.py

* Fix assertions

* Make style

* Fix ...Model input args and adjust one new test

* Add back head_mask and decoder_head_mask to BART-based ...Model
after the last commit

* Remove head_mask and decoder_head_mask from input_dict
in TF test_train_pipeline_custom_model, as these two have different
shapes than the other input args (necessary for passing this test)

* Revert adding global_rng in test_modeling_tf_common.py
stancld authored Jan 26, 2021
1 parent cb73ab5 commit 1867d9a
Showing 32 changed files with 849 additions and 36 deletions.
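
The new arguments can be exercised as below — a minimal sketch, not part of this commit, using a tiny randomly initialized model (the config values and token ids are illustrative). Shapes follow the `(num_layers, num_heads)` convention used by the tests in this diff:

```python
import tensorflow as tf
from transformers import BartConfig, TFBartModel

# Toy config so the example runs quickly; values are illustrative only.
config = BartConfig(
    encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=4, decoder_attention_heads=4,
    d_model=16, encoder_ffn_dim=32, decoder_ffn_dim=32, vocab_size=100,
)
model = TFBartModel(config)
input_ids = tf.constant([[5, 6, 7, 2]])  # made-up token ids; 2 = eos

# 0.0 disables a head, 1.0 keeps it; here head 0 of every encoder layer is off.
head_mask = tf.concat(
    [tf.zeros((config.encoder_layers, 1)), tf.ones((config.encoder_layers, 3))],
    axis=-1,
)
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))

outputs = model(
    input_ids,
    decoder_input_ids=input_ids,
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
)
print(outputs.last_hidden_state.shape)  # (1, 4, 16)
```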
119 changes: 114 additions & 5 deletions src/transformers/models/bart/modeling_tf_bart.py

Large diffs are not rendered by default.

120 changes: 115 additions & 5 deletions src/transformers/models/blenderbot/modeling_tf_blenderbot.py

Large diffs are not rendered by default.

src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py

Large diffs are not rendered by default.

120 changes: 115 additions & 5 deletions src/transformers/models/marian/modeling_tf_marian.py

Large diffs are not rendered by default.

120 changes: 115 additions & 5 deletions src/transformers/models/mbart/modeling_tf_mbart.py

Large diffs are not rendered by default.

120 changes: 115 additions & 5 deletions src/transformers/models/pegasus/modeling_tf_pegasus.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions tests/test_modeling_tf_albert.py
@@ -240,6 +240,7 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
test_head_masking = False

def setUp(self):
self.model_tester = TFAlbertModelTester(self)
12 changes: 11 additions & 1 deletion tests/test_modeling_tf_bart.py
@@ -108,10 +108,11 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict):

input_ids = input_ids[:1, :]
attention_mask = inputs_dict["attention_mask"][:1, :]
head_mask = inputs_dict["head_mask"]
self.batch_size = 1

# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
@@ -144,6 +145,8 @@ def prepare_bart_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -155,11 +158,17 @@
],
axis=-1,
)
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": head_mask,
}


@@ -169,6 +178,7 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = True

def setUp(self):
self.model_tester = TFBartModelTester(self)
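
The all-ones defaults above keep every head active, so the masks are no-ops for existing tests. The mask semantics — a per-layer vector multiplied into the per-head attention probabilities — can be shown in isolation; a toy sketch mirroring how the BART attention applies `layer_head_mask` (the tensors here are made up):

```python
import tensorflow as tf

# (batch, heads, tgt_len, src_len) attention probabilities, random for the demo
attn_probs = tf.random.uniform((2, 4, 5, 5))
layer_head_mask = tf.constant([0.0, 1.0, 1.0, 1.0])  # disable head 0 only

# Broadcast the mask over the heads dimension
masked = attn_probs * layer_head_mask[None, :, None, None]
print(tf.reduce_sum(masked[:, 0]).numpy())  # 0.0 — head 0 contributes nothing
```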
1 change: 1 addition & 0 deletions tests/test_modeling_tf_bert.py
@@ -273,6 +273,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
test_head_masking = False

# special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
12 changes: 11 additions & 1 deletion tests/test_modeling_tf_blenderbot.py
@@ -107,10 +107,11 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict):

input_ids = input_ids[:1, :]
attention_mask = inputs_dict["attention_mask"][:1, :]
head_mask = inputs_dict["head_mask"]
self.batch_size = 1

# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
@@ -143,6 +144,8 @@ def prepare_blenderbot_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -154,11 +157,17 @@
],
axis=-1,
)
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
}


@@ -168,6 +177,7 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = True

def setUp(self):
self.model_tester = TFBlenderbotModelTester(self)
12 changes: 11 additions & 1 deletion tests/test_modeling_tf_blenderbot_small.py
@@ -107,10 +107,11 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict):

input_ids = input_ids[:1, :]
attention_mask = inputs_dict["attention_mask"][:1, :]
head_mask = inputs_dict["head_mask"]
self.batch_size = 1

# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
@@ -143,6 +144,8 @@ def prepare_blenderbot_small_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -154,11 +157,17 @@
],
axis=-1,
)
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
}


@@ -170,6 +179,7 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFBlenderbotSmallForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = True

def setUp(self):
self.model_tester = TFBlenderbotSmallModelTester(self)
74 changes: 74 additions & 0 deletions tests/test_modeling_tf_common.py
@@ -440,6 +440,11 @@ def test_pt_tf_model_equivalence(self):

def test_train_pipeline_custom_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# head_mask and decoder_head_mask have different shapes than the other input args
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
if "decoder_head_mask" in inputs_dict:
del inputs_dict["decoder_head_mask"]
tf_main_layer_classes = set(
module_member
for model_class in self.all_model_classes
@@ -620,6 +625,75 @@ def check_encoder_attentions_output(outputs):
self.assertEqual(model.config.output_hidden_states, True)
check_encoder_attentions_output(outputs)

def test_headmasking(self):
if not self.test_head_masking:
return

random.seed(42)  # seed the module-level RNG; seeding a throwaway random.Random() instance has no effect
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
random.seed()  # re-randomize the RNG afterwards

inputs_dict["output_attentions"] = True
config.output_hidden_states = True
configs_no_init = _config_zero_init(config) # To be sure we have no NaN
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)

# Prepare head_mask
def prepare_layer_head_mask(i, attention_heads, num_hidden_layers):
if i == 0:
return tf.concat(
(tf.zeros(1, dtype=tf.float32), tf.ones(attention_heads - 1, dtype=tf.float32)), 0
)
elif i == num_hidden_layers - 1:
return tf.concat(
(tf.zeros(attention_heads - 1, dtype=tf.float32), tf.ones(1, dtype=tf.float32)), 0
)
else:
return tf.ones(attention_heads, dtype=tf.float32)

head_mask = tf.stack(
[
prepare_layer_head_mask(i, config.num_attention_heads, config.num_hidden_layers)
for i in range(config.num_hidden_layers)
],
0,
)

inputs = self._prepare_for_class(inputs_dict, model_class).copy()
inputs["head_mask"] = head_mask
if model.config.is_encoder_decoder:
signature = inspect.signature(model.call)
arg_names = [*signature.parameters.keys()]
if "decoder_head_mask" in arg_names: # necessary diferentiation because of T5 model
inputs["decoder_head_mask"] = head_mask

outputs = model(**inputs, return_dict=True)

def check_attentions_validity(attentions):
# Remove NaN values
for t in attentions:
self.assertLess(
(tf.math.reduce_sum(tf.cast(tf.math.is_nan(t), tf.float32))).numpy(), (tf.size(t) / 4).numpy()
) # Check that we don't have more than 25% NaNs (arbitrary threshold)

attentions = [
tf.where(tf.math.is_nan(t), 0.0, t) for t in attentions
] # zero out the NaNs (this makes the test less strict)

self.assertAlmostEqual(tf.math.reduce_sum(attentions[0][..., 0, :, :]).numpy(), 0.0)
self.assertNotEqual(tf.math.reduce_sum(attentions[0][..., -1, :, :]).numpy(), 0.0)
if len(attentions) > 2: # encoder-decoder models have only 2 layers in each module
self.assertNotEqual(tf.math.reduce_sum(attentions[1][..., 0, :, :]).numpy(), 0.0)
self.assertAlmostEqual(tf.math.reduce_sum(attentions[-1][..., -2, :, :]).numpy(), 0.0)
self.assertNotEqual(tf.math.reduce_sum(attentions[-1][..., -1, :, :]).numpy(), 0.0)

if model.config.is_encoder_decoder:
check_attentions_validity(outputs.encoder_attentions)
check_attentions_validity(outputs.decoder_attentions)
else:
check_attentions_validity(outputs.attentions)

def test_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

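
The commit message leaves a TODO about verifying that a gradient can be recovered for head-importance computation. A hedged sketch of what such a check could look like — not code from this commit; the tiny config and proxy loss are illustrative — relying only on the head_mask argument added here:

```python
import tensorflow as tf
from transformers import BartConfig, TFBartModel

# Toy config so the example runs quickly; values are illustrative only.
config = BartConfig(
    encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=4, decoder_attention_heads=4,
    d_model=16, encoder_ffn_dim=32, decoder_ffn_dim=32, vocab_size=100,
)
model = TFBartModel(config)
input_ids = tf.constant([[5, 6, 7, 2]])  # made-up token ids

# A trainable mask lets the tape track gradients through the attention weights
head_mask = tf.Variable(tf.ones((config.encoder_layers, config.encoder_attention_heads)))
with tf.GradientTape() as tape:
    outputs = model(input_ids, decoder_input_ids=input_ids, head_mask=head_mask)
    proxy_loss = tf.reduce_sum(outputs.last_hidden_state ** 2)  # toy objective

# |d loss / d mask| as a per-head sensitivity score
importance = tf.abs(tape.gradient(proxy_loss, head_mask))
print(importance.shape)  # (encoder_layers, encoder_attention_heads)
```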
1 change: 1 addition & 0 deletions tests/test_modeling_tf_ctrl.py
@@ -173,6 +173,7 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):

all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel, TFCTRLForSequenceClassification) if is_tf_available() else ()
all_generative_model_classes = (TFCTRLLMHeadModel,) if is_tf_available() else ()
test_head_masking = False

def setUp(self):
self.model_tester = TFCTRLModelTester(self)
1 change: 1 addition & 0 deletions tests/test_modeling_tf_distilbert.py
@@ -183,6 +183,7 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else None
)
test_head_masking = False

def setUp(self):
self.model_tester = TFDistilBertModelTester(self)
1 change: 1 addition & 0 deletions tests/test_modeling_tf_electra.py
@@ -205,6 +205,7 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
test_head_masking = False

def setUp(self):
self.model_tester = TFElectraModelTester(self)
1 change: 1 addition & 0 deletions tests/test_modeling_tf_flaubert.py
@@ -291,6 +291,7 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (
(TFFlaubertWithLMHeadModel,) if is_tf_available() else ()
) # TODO (PVP): Check other models whether language generation is also applicable
test_head_masking = False

def setUp(self):
self.model_tester = TFFlaubertModelTester(self)
2 changes: 2 additions & 0 deletions tests/test_modeling_tf_funnel.py
@@ -338,6 +338,7 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
test_head_masking = False

def setUp(self):
self.model_tester = TFFunnelModelTester(self)
@@ -376,6 +377,7 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (
(TFFunnelBaseModel, TFFunnelForMultipleChoice, TFFunnelForSequenceClassification) if is_tf_available() else ()
)
test_head_masking = False

def setUp(self):
self.model_tester = TFFunnelModelTester(self, base=True)
1 change: 1 addition & 0 deletions tests/test_modeling_tf_gpt2.py
@@ -332,6 +332,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
else ()
)
all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else ()
test_head_masking = False

def setUp(self):
self.model_tester = TFGPT2ModelTester(self)
1 change: 1 addition & 0 deletions tests/test_modeling_tf_led.py
@@ -187,6 +187,7 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = False

def setUp(self):
self.model_tester = TFLEDModelTester(self)
Expand Down
1 change: 1 addition & 0 deletions tests/test_modeling_tf_longformer.py
@@ -297,6 +297,7 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
test_head_masking = False

def setUp(self):
self.model_tester = TFLongformerModelTester(self)
Expand Down
1 change: 1 addition & 0 deletions tests/test_modeling_tf_lxmert.py
@@ -361,6 +361,7 @@ def create_and_check_lxmert_for_pretraining(
class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):

all_model_classes = (TFLxmertModel, TFLxmertForPreTraining) if is_tf_available() else ()
test_head_masking = False

def setUp(self):
self.model_tester = TFLxmertModelTester(self)
Expand Down
12 changes: 11 additions & 1 deletion tests/test_modeling_tf_marian.py
@@ -109,10 +109,11 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict):

input_ids = input_ids[:1, :]
attention_mask = inputs_dict["attention_mask"][:1, :]
head_mask = inputs_dict["head_mask"]
self.batch_size = 1

# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
@@ -145,6 +146,8 @@ def prepare_marian_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -156,11 +159,17 @@
],
axis=-1,
)
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
}


@@ -170,6 +179,7 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = True

def setUp(self):
self.model_tester = TFMarianModelTester(self)