dataset.py
from typing import Mapping

import torch
from torch.utils.data import Dataset
from transformers import DistilBertTokenizer

from bert_ner.catalyst_ext import StateKeys


class KeyphrasesDataset(Dataset):
    def __init__(self, texts, keyphrases, keys: StateKeys, max_seq_length=512):
        self.texts = texts
        self.keyphrases = keyphrases
        self.keys = keys
        self.max_seq_length = max_seq_length
        self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        self.sep_vid = self.tokenizer.vocab['[SEP]']
        self.cls_vid = self.tokenizer.vocab['[CLS]']
        self.pad_vid = self.tokenizer.vocab['[PAD]']

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index) -> Mapping[str, torch.Tensor]:
        def _find_inclusions(list_a, list_b):
            # Start indices at which list_b occurs as a contiguous sublist of list_a.
            return [i for i in range(len(list_a)) if list_a[i:i + len(list_b)] == list_b]

        x, y = self.texts[index], self.keyphrases[index]

        # Tokenize the text, adding [CLS]/[SEP] and truncating to max_seq_length.
        x_encoded = self.tokenizer.encode(
            x,
            add_special_tokens=True,
            max_length=self.max_seq_length,
            return_tensors='pt',
        ).squeeze(0)

        # Pad the token ids up to max_seq_length with [PAD].
        true_seq_length = x_encoded.size(0)
        pad_size = self.max_seq_length - true_seq_length
        pad_ids = torch.full((pad_size,), self.pad_vid, dtype=torch.long)
        x_tensor = torch.cat((x_encoded, pad_ids))

        # Attention mask: 1 for real tokens, 0 for padding.
        attention_mask = torch.ones_like(x_encoded, dtype=torch.long)
        mask_pad = torch.zeros_like(pad_ids, dtype=torch.long)
        attention_mask = torch.cat((attention_mask, mask_pad))

        # Encode each keyphrase without special tokens so that its token ids can be
        # matched as a contiguous span inside the encoded text.
        x_list = x_encoded.tolist()
        y_encoded = [
            self.tokenizer.encode(phrase, add_special_tokens=False) for phrase in y
        ]
        y_start_pos = [_find_inclusions(x_list, ids) for ids in y_encoded]

        # Label every token position covered by the first occurrence of each keyphrase.
        y_positions = [
            list(range(pos[0], pos[0] + len(ids)))
            for pos, ids in zip(y_start_pos, y_encoded) if pos
        ]
        y_positions = [item for sublist in y_positions for item in sublist]
        labels = torch.zeros_like(x_tensor)
        labels[y_positions] = 1

        return {
            self.keys.input_ids: x_tensor,
            self.keys.targets: labels,
            self.keys.attention_mask: attention_mask,
        }
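

# A minimal usage sketch, not part of the original file. The real project passes a
# bert_ner.catalyst_ext.StateKeys instance; since its constructor is not shown here,
# a SimpleNamespace exposing the three attribute names the dataset reads is used as
# a stand-in. The example texts, keyphrases, and batch size are illustrative
# assumptions.
if __name__ == '__main__':
    from types import SimpleNamespace
    from torch.utils.data import DataLoader

    keys = SimpleNamespace(
        input_ids='input_ids',
        targets='targets',
        attention_mask='attention_mask',
    )
    texts = ['BERT is a transformer model for natural language processing.']
    keyphrases = [['transformer model', 'natural language processing']]

    dataset = KeyphrasesDataset(texts, keyphrases, keys, max_seq_length=64)
    loader = DataLoader(dataset, batch_size=1)

    batch = next(iter(loader))
    # Each batch maps the configured key names to padded tensors of shape
    # (batch_size, max_seq_length).
    print(batch['input_ids'].shape, batch['targets'].shape, batch['attention_mask'].shape)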