diff --git a/src/transformers/trainer_tf.py b/src/transformers/trainer_tf.py
index ac75eb6223cd..d6c92ebc0567 100644
--- a/src/transformers/trainer_tf.py
+++ b/src/transformers/trainer_tf.py
@@ -638,7 +638,15 @@ def apply_gradients(self, features, labels, nb_instances_in_global_batch):
                 reduced_features = {
                     k: ft[: self.args.train_batch_size // self.args.n_replicas] for k, ft in features.items()
                 }
-                reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas]
+
+                if tf.is_tensor(labels):
+                    reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas]
+                elif isinstance(labels, dict):
+                    reduced_labels = {
+                        k: lbl[: self.args.train_batch_size // self.args.n_replicas] for k, lbl in labels.items()
+                    }
+                else:
+                    raise ValueError("The labels must be either a tf.Tensor or a dict.")
 
                 self.training_step(reduced_features, reduced_labels, nb_instances_in_global_batch)
 
@@ -650,9 +658,20 @@ def apply_gradients(self, features, labels, nb_instances_in_global_batch):
                     for k, ft in features.items()
                 }
 
-                labels = tf.concat(
-                    [labels[self.args.train_batch_size // self.args.n_replicas :], reduced_labels], axis=0
-                )
+                if tf.is_tensor(labels):
+                    labels = tf.concat(
+                        [labels[self.args.train_batch_size // self.args.n_replicas :], reduced_labels], axis=0
+                    )
+                elif isinstance(labels, dict):
+                    labels = {
+                        k: tf.concat(
+                            [lbl[self.args.train_batch_size // self.args.n_replicas :], reduced_labels[k]],
+                            axis=0,
+                        )
+                        for k, lbl in labels.items()
+                    }
+                else:
+                    raise ValueError("The labels must be either a tf.Tensor or a dict.")
 
             gradients = self.gradient_accumulator.gradients
             gradients = [
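
For context, a minimal standalone sketch of the behavior this diff adds: during gradient accumulation, the trainer repeatedly takes the leading per-replica micro-batch and then rotates the consumed slice to the back, and the patch extends this to labels passed as a dict (e.g. multi-head models) in addition to a plain `tf.Tensor`. The helper names and `micro_batch_size` below are hypothetical, standing in for `train_batch_size // n_replicas`:

```python
import tensorflow as tf


def reduce_labels(labels, micro_batch_size):
    """Take the leading micro-batch from tensor or dict labels."""
    if tf.is_tensor(labels):
        return labels[:micro_batch_size]
    elif isinstance(labels, dict):
        return {k: lbl[:micro_batch_size] for k, lbl in labels.items()}
    else:
        raise ValueError("The labels must be either a tf.Tensor or a dict.")


def rotate_labels(labels, reduced_labels, micro_batch_size):
    """Move the consumed micro-batch to the end so the next
    accumulation step sees a fresh leading slice."""
    if tf.is_tensor(labels):
        return tf.concat([labels[micro_batch_size:], reduced_labels], axis=0)
    elif isinstance(labels, dict):
        return {
            k: tf.concat([lbl[micro_batch_size:], reduced_labels[k]], axis=0)
            for k, lbl in labels.items()
        }
    else:
        raise ValueError("The labels must be either a tf.Tensor or a dict.")


# Tensor labels: the first two examples are consumed, then rotated to the back.
labels = tf.constant([0, 1, 2, 3])
reduced = reduce_labels(labels, 2)          # [0, 1]
labels = rotate_labels(labels, reduced, 2)  # [2, 3, 0, 1]

# Dict labels (e.g. separate start/end heads) now work the same way per key.
labels = {"start": tf.constant([0, 1, 2, 3]), "end": tf.constant([4, 5, 6, 7])}
reduced = reduce_labels(labels, 2)          # {"start": [0, 1], "end": [4, 5]}
labels = rotate_labels(labels, reduced, 2)  # {"start": [2, 3, 0, 1], ...}
```

Rotating the consumed slice to the back, rather than advancing a moving window, means every accumulation step slices the same leading range with a static size, which keeps tensor shapes fixed across iterations inside a `tf.function`.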