From ccea215081e8de4637ba1706906574037c150a80 Mon Sep 17 00:00:00 2001
From: Julien Plu
Date: Mon, 18 Jan 2021 15:46:17 +0100
Subject: [PATCH 1/3] Fix Flaubert and XLM

---
 src/transformers/models/flaubert/modeling_tf_flaubert.py | 9 ++++++---
 src/transformers/models/xlm/modeling_tf_xlm.py           | 9 ++++++---
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/flaubert/modeling_tf_flaubert.py b/src/transformers/models/flaubert/modeling_tf_flaubert.py
index 28ebba7daa40..0b614ab94294 100644
--- a/src/transformers/models/flaubert/modeling_tf_flaubert.py
+++ b/src/transformers/models/flaubert/modeling_tf_flaubert.py
@@ -214,10 +214,13 @@ def dummy_inputs(self):
         inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
         attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
         if self.config.use_lang_emb and self.config.n_langs > 1:
-            langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
+            return {
+                "input_ids": inputs_list,
+                "attention_mask": attns_list,
+                "langs_list": tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]),
+            }
         else:
-            langs_list = None
-        return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list}
+            return {"input_ids": inputs_list, "attention_mask": attns_list}
 
 
 @add_start_docstrings(
diff --git a/src/transformers/models/xlm/modeling_tf_xlm.py b/src/transformers/models/xlm/modeling_tf_xlm.py
index 31d8b4fd49a1..afaa91eb612f 100644
--- a/src/transformers/models/xlm/modeling_tf_xlm.py
+++ b/src/transformers/models/xlm/modeling_tf_xlm.py
@@ -536,10 +536,13 @@ def dummy_inputs(self):
         inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
         attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
         if self.config.use_lang_emb and self.config.n_langs > 1:
-            langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
+            return {
+                "input_ids": inputs_list,
+                "attention_mask": attns_list,
+                "langs_list": tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]),
+            }
         else:
-            langs_list = None
-        return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list}
+            return {"input_ids": inputs_list, "attention_mask": attns_list}
 
 
 # Remove when XLMWithLMHead computes loss like other LM models

From 3f421833334868b8b6e4c242da1cd775735ad61b Mon Sep 17 00:00:00 2001
From: Julien Plu
Date: Mon, 18 Jan 2021 18:01:53 +0100
Subject: [PATCH 2/3] Fix Flaubert and XLM

---
 .../models/flaubert/modeling_tf_flaubert.py    |  2 +-
 src/transformers/models/xlm/modeling_tf_xlm.py | 17 ++++++++++++-----
 2 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/flaubert/modeling_tf_flaubert.py b/src/transformers/models/flaubert/modeling_tf_flaubert.py
index 0b614ab94294..f24dfa747380 100644
--- a/src/transformers/models/flaubert/modeling_tf_flaubert.py
+++ b/src/transformers/models/flaubert/modeling_tf_flaubert.py
@@ -217,7 +217,7 @@ def dummy_inputs(self):
             return {
                 "input_ids": inputs_list,
                 "attention_mask": attns_list,
-                "langs_list": tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]),
+                "langs": tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]),
             }
         else:
             return {"input_ids": inputs_list, "attention_mask": attns_list}
diff --git a/src/transformers/models/xlm/modeling_tf_xlm.py b/src/transformers/models/xlm/modeling_tf_xlm.py
index afaa91eb612f..e56be1738ef6 100644
--- a/src/transformers/models/xlm/modeling_tf_xlm.py
+++ b/src/transformers/models/xlm/modeling_tf_xlm.py
@@ -539,7 +539,7 @@ def dummy_inputs(self):
             return {
                 "input_ids": inputs_list,
                 "attention_mask": attns_list,
-                "langs_list": tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]),
+                "langs": tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]),
             }
         else:
             return {"input_ids": inputs_list, "attention_mask": attns_list}
@@ -1048,10 +1048,17 @@ def dummy_inputs(self):
         Returns:
             tf.Tensor with dummy inputs
         """
-        return {
-            "input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
-            "langs": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
-        }
+        # Sometimes XLM has language embeddings so don't forget to build them as well if needed
+        if self.config.use_lang_emb and self.config.n_langs > 1:
+            return {
+                "input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
+                "langs": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
+            }
+        else:
+            return {
+                "input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
+            }
+
 
 @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
 @add_code_sample_docstrings(

From 0460520da6df3d0d47b2c9cf71218507f9166a98 Mon Sep 17 00:00:00 2001
From: Julien Plu
Date: Mon, 18 Jan 2021 18:05:36 +0100
Subject: [PATCH 3/3] Apply style

---
 src/transformers/models/xlm/modeling_tf_xlm.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/transformers/models/xlm/modeling_tf_xlm.py b/src/transformers/models/xlm/modeling_tf_xlm.py
index e56be1738ef6..8cd3c7ef4814 100644
--- a/src/transformers/models/xlm/modeling_tf_xlm.py
+++ b/src/transformers/models/xlm/modeling_tf_xlm.py
@@ -1058,7 +1058,6 @@ def dummy_inputs(self):
             return {
                 "input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
             }
-
 
 @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
 @add_code_sample_docstrings(