-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdata.py
128 lines (102 loc) · 4.25 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import pytorch_lightning as pl
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
def convert_to_features(example_batch, indices, tokenizer, text_fields, padding, truncation, max_length):
    """Tokenize one batch of examples; meant for ``Dataset.map(batched=True, with_indices=True)``.

    Args:
        example_batch: Mapping of column name -> list of values for this batch.
        indices: Row indices of the batch; stored under ``idx`` so predictions can be
            linked back to the original data.
        tokenizer: A Hugging Face tokenizer; it is called directly.
        text_fields: One column name for single-sentence input, or two column names
            for sentence pairs. Only the first two entries are used.
        padding: Forwarded unchanged to the tokenizer.
        truncation: Forwarded unchanged to the tokenizer.
        max_length: Forwarded unchanged to the tokenizer.

    Returns:
        The tokenizer's output mapping with an extra ``idx`` entry.
    """
    first_texts = example_batch[text_fields[0]]
    # For sentence-pair tasks the second column goes in as ``text_pair``; the
    # tokenizer pairs the two lists element-wise.
    second_texts = example_batch[text_fields[1]] if len(text_fields) > 1 else None

    features = tokenizer(
        first_texts,
        text_pair=second_texts,
        padding=padding,
        truncation=truncation,
        max_length=max_length,
    )
    # idx is a unique ID we can use to link predictions to the original data
    features['idx'] = indices
    return features
def preprocess(ds, tokenizer, text_fields, padding='max_length', truncation='only_first', max_length=128):
    """Tokenize every example of *ds* and rename its ``label`` column to ``labels``.

    Args:
        ds: A Hugging Face ``Dataset``/``DatasetDict`` that has a ``label`` column.
        tokenizer: Tokenizer forwarded to ``convert_to_features``.
        text_fields: Text column(s) to tokenize (one, or two for sentence pairs).
        padding: Tokenizer padding setting.
        truncation: Tokenizer truncation setting.
        max_length: Tokenizer maximum sequence length.

    Returns:
        A new dataset with the tokenizer outputs, an ``idx`` column, and ``labels``.
    """
    ds = ds.map(
        convert_to_features,
        batched=True,
        with_indices=True,
        fn_kwargs={
            'tokenizer': tokenizer,
            'text_fields': text_fields,
            'padding': padding,
            'truncation': truncation,
            'max_length': max_length,
        },
    )
    # ``rename_column`` returns a new dataset. The in-place ``rename_column_`` was
    # deprecated and has been removed from recent ``datasets`` releases.
    ds = ds.rename_column('label', 'labels')
    return ds
def transform_labels(example, idx, label2id: dict):
    """Map one example's ``labels`` value to its integer id and record its index.

    The *example* dict is modified in place and also returned, as expected by
    ``Dataset.map(with_indices=True)``.

    Args:
        example: A single dataset row containing a ``labels`` entry.
        idx: Row index, stored under ``idx``.
        label2id: Mapping from label name to integer id.

    Returns:
        The same *example* object with ``labels`` converted and ``idx`` set.

    Raises:
        KeyError: If ``labels`` is neither a key of *label2id* nor already an int.
    """
    label = example['labels']
    if label in label2id:
        example['labels'] = label2id[label]
    elif not isinstance(label, int):
        # Unknown label name: fail loudly instead of emitting a bogus id.
        raise KeyError(label)
    # Otherwise the label is already an integer id (e.g. the dataset stores
    # class ids rather than names), so it is kept unchanged.
    example['idx'] = idx
    return example
class TextClassificationDataModule(pl.LightningDataModule):
    """Load, tokenize and batch a Hugging Face text-classification dataset.

    Concrete subclasses describe the dataset through class attributes:

    * ``dataset_name`` / ``subset_name``: arguments passed to ``load_dataset``.
    * ``text_fields``: one column name (single sentence) or two (sentence pair).
    * ``label2id``: mapping applied by ``transform_labels`` when
      ``do_transform_labels`` is true.
    * ``train_val_split``: when not ``None``, passed as ``test_size`` to
      ``train_test_split`` to carve a ``validation`` split out of ``train``.

    Args:
        model_name_or_path: Tokenizer name/path for ``AutoTokenizer.from_pretrained``.
        padding: Tokenizer padding setting forwarded to ``preprocess``.
        truncation: Tokenizer truncation setting forwarded to ``preprocess``.
        max_length: Tokenizer maximum length forwarded to ``preprocess``.
        batch_size: Batch size used by every dataloader.
        num_workers: ``num_workers`` used by every dataloader.
        use_fast: Passed to ``AutoTokenizer.from_pretrained``.
        seed: Seed for the train/validation split, so it is reproducible.
    """

    # Defaults for the optional configuration. Subclasses must still define
    # dataset_name and text_fields, plus label2id when do_transform_labels is true.
    subset_name = None
    do_transform_labels = False
    train_val_split = None

    def __init__(
        self,
        model_name_or_path: str = 'bert-base-uncased',
        padding: str = 'max_length',
        truncation: str = 'only_first',
        max_length: int = 128,
        batch_size: int = 16,
        num_workers: int = 8,
        use_fast: bool = True,
        seed: int = 42,
    ):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.padding = padding
        self.truncation = truncation
        self.max_length = max_length
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.use_fast = use_fast
        self.seed = seed

    def prepare_data(self):
        """Download the tokenizer and dataset ahead of ``setup``.

        Lightning runs this hook in a single process, so downloads are not
        repeated by every distributed worker. The later ``setup`` calls are then
        expected to be served from the local cache.
        """
        AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=self.use_fast)
        load_dataset(self.dataset_name, self.subset_name)

    def setup(self, stage=None):
        """Build the tokenized splits and store them on ``self.ds``.

        ``stage`` is ignored. It defaults to ``None`` to match
        ``LightningDataModule.setup``.
        """
        tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=self.use_fast)
        ds = load_dataset(self.dataset_name, self.subset_name)

        if self.train_val_split is not None:
            # Seeded so every run and every distributed process gets the same
            # validation set; an unseeded split would differ across processes.
            split = ds['train'].train_test_split(test_size=self.train_val_split, seed=self.seed)
            ds['train'] = split['train']
            ds['validation'] = split['test']

        ds = preprocess(ds, tokenizer, self.text_fields, self.padding, self.truncation, self.max_length)
        if self.do_transform_labels:
            ds = ds.map(transform_labels, with_indices=True, fn_kwargs={'label2id': self.label2id})

        # Expose only model inputs the tokenizer actually produced; e.g.
        # token_type_ids may be absent depending on the tokenizer.
        cols_to_keep = [
            col
            for col in ['input_ids', 'attention_mask', 'token_type_ids', 'labels', 'idx']
            if col in ds['train'].features
        ]
        ds.set_format("torch", columns=cols_to_keep)

        self.ds = ds
        self.tokenizer = tokenizer

    def _dataloader(self, split, shuffle=False):
        """Return a DataLoader over ``self.ds[split]`` with the configured batching."""
        return DataLoader(
            self.ds[split],
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        # Training examples are reshuffled every epoch; evaluation keeps dataset order.
        return self._dataloader('train', shuffle=True)

    def val_dataloader(self):
        return self._dataloader('validation')

    def test_dataloader(self):
        return self._dataloader('test')
class EmotionDataModule(TextClassificationDataModule):
    """Data module for the ``emotion`` dataset.

    Single text column; labels are mapped to ids through ``label2id``.
    """

    # Dataset location on the Hugging Face hub (no sub-configuration).
    dataset_name = 'emotion'
    subset_name = None

    # Model input column.
    text_fields = ['text']

    # Label names in id order: "sadness" -> 0 ... "surprise" -> 5.
    label2id = {
        name: position
        for position, name in enumerate(("sadness", "joy", "love", "anger", "fear", "surprise"))
    }
    do_transform_labels = True

    # The dataset's own splits are used as-is.
    train_val_split = None
class AGNewsDataModule(TextClassificationDataModule):
    """Data module for the ``ag_news`` dataset.

    Labels are used as-is. A validation split of 20 000 examples is carved
    out of the training split.
    """

    dataset_name = 'ag_news'
    subset_name = None
    text_fields = ['text']

    # Label names in id order; kept for reference, since labels are not transformed.
    label2id = {
        name: position
        for position, name in enumerate(("World", "Sports", "Business", "Sci/Tech"))
    }
    do_transform_labels = False

    # Passed as test_size when splitting train into train/validation.
    train_val_split = 20_000
class MrpcDataModule(TextClassificationDataModule):
    """Data module for the GLUE MRPC sentence-pair task.

    Both sentences are tokenized together; labels are used as-is.
    """

    # GLUE benchmark, MRPC configuration.
    dataset_name = 'glue'
    subset_name = 'mrpc'

    # Two text columns -> encoded as a sentence pair.
    text_fields = ['sentence1', 'sentence2']

    # Label names in id order; kept for reference, since labels are not transformed.
    label2id = dict(not_equivalent=0, equivalent=1)
    do_transform_labels = False

    # The dataset already ships a validation split.
    train_val_split = None