import os
import random
import re
import tarfile
import urllib.request

from torchtext import data

class TarDataset(data.Dataset):
    """Defines a Dataset loaded from a downloadable tar archive.

    Attributes:
        url: URL where the tar archive can be downloaded.
        filename: Filename of the downloaded tar archive.
        dirname: Name of the top-level directory within the tar archive
            that contains the data files.
    """

    @classmethod
    def download_or_unzip(cls, root):
        path = os.path.join(root, cls.dirname)
        if not os.path.isdir(path):
            tpath = os.path.join(root, cls.filename)
            if not os.path.isfile(tpath):
                print('downloading')
                urllib.request.urlretrieve(cls.url, tpath)
            with tarfile.open(tpath, 'r') as tfile:
                print('extracting')

                def is_within_directory(directory, target):
                    # The target path must resolve to somewhere inside the
                    # extraction directory.
                    abs_directory = os.path.abspath(directory)
                    abs_target = os.path.abspath(target)
                    prefix = os.path.commonprefix([abs_directory, abs_target])
                    return prefix == abs_directory

                def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
                    # Guard against tar path traversal (CVE-2007-4559):
                    # refuse to extract members that would escape `path`.
                    for member in tar.getmembers():
                        member_path = os.path.join(path, member.name)
                        if not is_within_directory(path, member_path):
                            raise Exception("Attempted Path Traversal in Tar File")
                    tar.extractall(path, members, numeric_owner=numeric_owner)

                safe_extract(tfile, root)
        return os.path.join(path, '')
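
# A quick illustration of the traversal check above (hypothetical paths,
# evaluated by hand against is_within_directory):
#   is_within_directory('/data', '/data/rt-polaritydata/rt-polarity.pos') -> True
#   is_within_directory('/data', '/data/../etc/passwd')                   -> False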

class MR(TarDataset):

    url = 'https://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz'
    filename = 'rt-polaritydata.tar.gz'
    dirname = 'rt-polaritydata'

    @staticmethod
    def sort_key(ex):
        return len(ex.text)

    def __init__(self, text_field, label_field, path=None, examples=None, **kwargs):
        """Create an MR dataset instance given a path and fields.

        Arguments:
            text_field: The field that will be used for text data.
            label_field: The field that will be used for label data.
            path: Path to the directory containing the data files.
            examples: The examples that contain all the data.
            Remaining keyword arguments: Passed to the constructor of
                data.Dataset.
        """
        def clean_str(string):
            """
            Tokenization/string cleaning for all datasets except for SST.
            Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
            """
            string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
            string = re.sub(r"\'s", " \'s", string)
            string = re.sub(r"\'ve", " \'ve", string)
            string = re.sub(r"n\'t", " n\'t", string)
            string = re.sub(r"\'re", " \'re", string)
            string = re.sub(r"\'d", " \'d", string)
            string = re.sub(r"\'ll", " \'ll", string)
            string = re.sub(r",", " , ", string)
            string = re.sub(r"!", " ! ", string)
            string = re.sub(r"\(", r" \( ", string)
            string = re.sub(r"\)", r" \) ", string)
            string = re.sub(r"\?", r" \? ", string)
            string = re.sub(r"\s{2,}", " ", string)
            return string.strip()
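
        # Illustrative example of the cleaning above (hand-evaluated):
        #   clean_str("It's a movie!") -> "It 's a movie !"
        # so the whitespace tokenizer below splits clitics and punctuation
        # into separate tokens.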
        text_field.tokenize = lambda x: clean_str(x).split()
        fields = [('text', text_field), ('label', label_field)]
        if examples is None:
            path = self.dirname if path is None else path
            examples = []
            with open(os.path.join(path, 'rt-polarity.neg'), errors='ignore') as f:
                examples += [
                    data.Example.fromlist([line, 'negative'], fields) for line in f]
            with open(os.path.join(path, 'rt-polarity.pos'), errors='ignore') as f:
                examples += [
                    data.Example.fromlist([line, 'positive'], fields) for line in f]
        super(MR, self).__init__(examples, fields, **kwargs)

    @classmethod
    def splits(cls, text_field, label_field, dev_ratio=.1, shuffle=True, root='.', **kwargs):
        """Create dataset objects for splits of the MR dataset.

        Arguments:
            text_field: The field that will be used for the sentence.
            label_field: The field that will be used for label data.
            dev_ratio: The fraction of the data held out as the validation set.
            shuffle: Whether to shuffle the data before splitting.
            root: The root directory that the dataset's tar archive will be
                expanded into; the data files are read from the dataset's
                subdirectory under this root.
            Remaining keyword arguments: Passed to the constructor of
                data.Dataset.
        """
        path = cls.download_or_unzip(root)
        examples = cls(text_field, label_field, path=path, **kwargs).examples
        if shuffle:
            random.shuffle(examples)
        dev_index = -1 * int(dev_ratio * len(examples))
        # The last dev_ratio fraction of the (shuffled) examples becomes the
        # dev set; everything before it is the training set.
        return (cls(text_field, label_field, examples=examples[:dev_index]),
                cls(text_field, label_field, examples=examples[dev_index:]))
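

# A minimal usage sketch, not part of the original module: it assumes a
# torchtext version that still ships the legacy `torchtext.data.Field` API
# (<= 0.8), and the Field parameters here are illustrative choices.
if __name__ == '__main__':
    text_field = data.Field(lower=True)          # tokenized movie-review text
    label_field = data.Field(sequential=False)   # 'positive' / 'negative'
    train, dev = MR.splits(text_field, label_field, dev_ratio=.1, root='.')
    text_field.build_vocab(train)
    label_field.build_vocab(train)
    print('train: {} examples, dev: {} examples'.format(len(train), len(dev)))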