Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to use datasets remove_columns method #11343

Merged
merged 2 commits into from
Apr 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/question-answering/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
datasets >= 1.2.1
datasets >= 1.4.0
13 changes: 1 addition & 12 deletions examples/question-answering/trainer_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,10 @@
A subclass of `Trainer` specific to Question-Answering tasks
"""

from transformers import Trainer, is_datasets_available, is_torch_tpu_available
from transformers import Trainer, is_torch_tpu_available
from transformers.trainer_utils import PredictionOutput


if is_datasets_available():
import datasets

if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
Expand Down Expand Up @@ -54,10 +51,6 @@ def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None):
finally:
self.compute_metrics = compute_metrics

# We might have removed columns from the dataset so we put them back.
if isinstance(eval_dataset, datasets.Dataset):
eval_dataset.set_format(type=eval_dataset.format["type"], columns=list(eval_dataset.features.keys()))

if self.post_process_function is not None and self.compute_metrics is not None:
eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions)
metrics = self.compute_metrics(eval_preds)
Expand Down Expand Up @@ -94,10 +87,6 @@ def predict(self, test_dataset, test_examples, ignore_keys=None):
if self.post_process_function is None or self.compute_metrics is None:
return output

# We might have removed columns from the dataset so we put them back.
if isinstance(test_dataset, datasets.Dataset):
test_dataset.set_format(type=test_dataset.format["type"], columns=list(test_dataset.features.keys()))

eval_preds = self.post_process_function(test_examples, test_dataset, output.predictions, "test")
metrics = self.compute_metrics(eval_preds)

Expand Down
33 changes: 19 additions & 14 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,11 +394,6 @@ def __init__(
raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")

self._signature_columns = None
if is_datasets_available():
if isinstance(train_dataset, datasets.Dataset):
self._remove_unused_columns(self.train_dataset, description="training")
if isinstance(eval_dataset, datasets.Dataset):
self._remove_unused_columns(self.eval_dataset, description="evaluation")

# Mixed precision setup
self.use_apex = False
Expand Down Expand Up @@ -503,7 +498,13 @@ def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optio
f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
)

dataset.set_format(type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"])
if version.parse(datasets.__version__) < version.parse("1.4.0"):
dataset.set_format(
type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
)
return dataset
else:
return dataset.remove_columns(ignored_columns)

def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if not isinstance(self.train_dataset, collections.abc.Sized):
Expand Down Expand Up @@ -565,17 +566,20 @@ def get_train_dataloader(self) -> DataLoader:
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")

if isinstance(self.train_dataset, torch.utils.data.dataset.IterableDataset):
train_dataset = self.train_dataset
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(train_dataset, description="training")

if isinstance(train_dataset, torch.utils.data.dataset.IterableDataset):
if self.args.world_size > 1:
train_dataset = IterableDatasetShard(
self.train_dataset,
train_dataset,
batch_size=self.args.train_batch_size,
drop_last=self.args.dataloader_drop_last,
num_processes=self.args.world_size,
process_index=self.args.process_index,
)
else:
train_dataset = self.train_dataset

return DataLoader(
train_dataset,
batch_size=self.args.train_batch_size,
Expand All @@ -587,7 +591,7 @@ def get_train_dataloader(self) -> DataLoader:
train_sampler = self._get_train_sampler()

return DataLoader(
self.train_dataset,
train_dataset,
batch_size=self.args.train_batch_size,
sampler=train_sampler,
collate_fn=self.data_collator,
Expand Down Expand Up @@ -638,10 +642,11 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
"""
if eval_dataset is None and self.eval_dataset is None:
raise ValueError("Trainer: evaluation requires an eval_dataset.")
elif is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
self._remove_unused_columns(eval_dataset, description="evaluation")
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")

if isinstance(eval_dataset, torch.utils.data.dataset.IterableDataset):
if self.args.world_size > 1:
eval_dataset = IterableDatasetShard(
Expand Down Expand Up @@ -683,7 +688,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
"""
if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
self._remove_unused_columns(test_dataset, description="test")
test_dataset = self._remove_unused_columns(test_dataset, description="test")

if isinstance(test_dataset, torch.utils.data.dataset.IterableDataset):
if self.args.world_size > 1:
Expand Down