Skip to content

Commit

Permalink
Fix label datatype in TF Trainer (#9616)
Browse files Browse the repository at this point in the history
* Fix label datatype

* Apply style
  • Loading branch information
jplu authored Jan 20, 2021
1 parent 76f36e1 commit 12f0d7e
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions src/transformers/trainer_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,15 @@ def apply_gradients(self, features, labels, nb_instances_in_global_batch):
reduced_features = {
k: ft[: self.args.train_batch_size // self.args.n_replicas] for k, ft in features.items()
}
reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas]

if tf.is_tensor(labels):
reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas]
elif isinstance(labels, dict):
reduced_labels = {
k: lbl[: self.args.train_batch_size // self.args.n_replicas] for k, lbl in labels.items()
}
else:
raise ValueError("The labels must be either a tf.Tensor or a dict.")

self.training_step(reduced_features, reduced_labels, nb_instances_in_global_batch)

Expand All @@ -650,9 +658,20 @@ def apply_gradients(self, features, labels, nb_instances_in_global_batch):
for k, ft in features.items()
}

labels = tf.concat(
[labels[self.args.train_batch_size // self.args.n_replicas :], reduced_labels], axis=0
)
if tf.is_tensor(labels):
labels = tf.concat(
[labels[self.args.train_batch_size // self.args.n_replicas :], reduced_labels], axis=0
)
elif isinstance(labels, dict):
labels = {
k: tf.concat(
[lbl[self.args.train_batch_size // self.args.n_replicas :], reduced_labels[k]],
axis=0,
)
for k, lbl in labels.items()
}
else:
raise ValueError("The labels must be either a tf.Tensor or a dict.")

gradients = self.gradient_accumulator.gradients
gradients = [
Expand Down

0 comments on commit 12f0d7e

Please sign in to comment.