Merge pull request #2149 from flairNLP/text-pair-dataset
Add option to load text pairs with CSVClassificationCorpus
alanakbik authored Mar 13, 2021
2 parents b415155 + ad2e1ed commit 51c1b5d
Showing 3 changed files with 131 additions and 117 deletions.
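For context, a minimal usage sketch of the new option. The data folder, file layout, delimiter, and label type below are hypothetical; the diff itself only adds support for mapping a CSV column to "pair":

from flair.datasets import CSVClassificationCorpus

# hypothetical tab-separated train/dev/test files with rows of the form:
# <label> <tab> <first text> <tab> <second text>
column_name_map = {0: "label", 1: "text", 2: "pair"}

corpus = CSVClassificationCorpus(
    "resources/tasks/my_pair_task",  # hypothetical data folder
    column_name_map,
    label_type="entailment",  # hypothetical label type
    delimiter="\t",
)

# with a "pair" column mapped, each data point is a DataPair rather than a plain Sentence
print(corpus.train[0])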
86 changes: 52 additions & 34 deletions flair/datasets/document_classification.py
@@ -11,7 +11,7 @@
     Corpus,
     Token,
     FlairDataset,
-    Tokenizer
+    Tokenizer, DataPair
 )
 from flair.tokenization import SegtokTokenizer, SpaceTokenizer
 from flair.datasets.base import find_train_dev_test_files
@@ -454,9 +454,12 @@ def __init__(

         # most data sets have the token text in the first column, if not, pass 'text' as column
         self.text_columns: List[int] = []
+        self.pair_columns: List[int] = []
         for column in column_name_map:
             if column_name_map[column] == "text":
                 self.text_columns.append(column)
+            if column_name_map[column] == "pair":
+                self.pair_columns.append(column)

         with open(self.path_to_file, encoding=encoding) as csv_file:

@@ -488,33 +491,61 @@ def __init__(

                 if self.in_memory:

-                    text = " ".join(
-                        [row[text_column] for text_column in self.text_columns]
-                    )
-
-                    if self.max_chars_per_doc > 0:
-                        text = text[: self.max_chars_per_doc]
-
-                    sentence = Sentence(text, use_tokenizer=self.tokenizer)
-
-                    for column in self.column_name_map:
-                        column_value = row[column]
-                        if (
-                            self.column_name_map[column].startswith("label")
-                            and column_value
-                        ):
-                            if column_value != self.no_class_label:
-                                sentence.add_label(label_type, column_value)
+                    sentence = self._make_labeled_data_point(row)

-                    if 0 < self.max_tokens_per_doc < len(sentence):
-                        sentence.tokens = sentence.tokens[: self.max_tokens_per_doc]
                     self.sentences.append(sentence)

                 else:
                     self.raw_data.append(row)

                 self.total_sentence_count += 1

+    def _make_labeled_data_point(self, row):
+
+        # make sentence from text (and filter for length)
+        text = " ".join(
+            [row[text_column] for text_column in self.text_columns]
+        )
+
+        if self.max_chars_per_doc > 0:
+            text = text[: self.max_chars_per_doc]
+
+        sentence = Sentence(text, use_tokenizer=self.tokenizer)
+
+        if 0 < self.max_tokens_per_doc < len(sentence):
+            sentence.tokens = sentence.tokens[: self.max_tokens_per_doc]
+
+        # if a pair column is defined, make a sentence pair object
+        if len(self.pair_columns) > 0:
+
+            text = " ".join(
+                [row[pair_column] for pair_column in self.pair_columns]
+            )
+
+            if self.max_chars_per_doc > 0:
+                text = text[: self.max_chars_per_doc]
+
+            pair = Sentence(text, use_tokenizer=self.tokenizer)
+
+            if 0 < self.max_tokens_per_doc < len(pair):
+                pair.tokens = pair.tokens[: self.max_tokens_per_doc]
+
+            data_point = DataPair(first=sentence, second=pair)
+
+        else:
+            data_point = sentence
+
+        for column in self.column_name_map:
+            column_value = row[column]
+            if (
+                self.column_name_map[column].startswith("label")
+                and column_value
+            ):
+                if column_value != self.no_class_label:
+                    data_point.add_label(self.label_type, column_value)
+
+        return data_point

     def is_in_memory(self) -> bool:
         return self.in_memory

@@ -527,20 +558,7 @@ def __getitem__(self, index: int = 0) -> Sentence:
         else:
             row = self.raw_data[index]

-            text = " ".join([row[text_column] for text_column in self.text_columns])
-
-            if self.max_chars_per_doc > 0:
-                text = text[: self.max_chars_per_doc]
-
-            sentence = Sentence(text, use_tokenizer=self.tokenizer)
-            for column in self.column_name_map:
-                column_value = row[column]
-                if self.column_name_map[column].startswith("label") and column_value:
-                    if column_value != self.no_class_label:
-                        sentence.add_label(self.label_type, column_value)
-
-            if 0 < self.max_tokens_per_doc < len(sentence):
-                sentence.tokens = sentence.tokens[: self.max_tokens_per_doc]
+            sentence = self._make_labeled_data_point(row)

             return sentence

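To illustrate the refactoring above: when a "pair" column is mapped, the new _make_labeled_data_point helper returns a DataPair built from two Sentence objects, with the row's labels attached to the pair itself. A rough equivalent (texts, label type, and label value are made up for illustration):

from flair.data import Sentence, DataPair

# one Sentence from the "text" column(s), one from the "pair" column(s)
first = Sentence("A man inspects the uniform of a figure.")
second = Sentence("The man is sleeping.")

# the label columns are attached to the pair, not to either sentence
data_point = DataPair(first=first, second=second)
data_point.add_label("entailment", "contradiction")  # hypothetical label type and value

print(data_point)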
5 changes: 2 additions & 3 deletions flair/datasets/sequence_labeling.py
@@ -3526,7 +3526,6 @@ def __init__(
         self,
         base_path: Union[str, Path] = None,
         in_memory: bool = True,
-        document_as_sequence: bool = False,
         **corpusargs,
     ):
         """
@@ -3571,7 +3570,7 @@

         for row in posts:  # Go through all the post titles

-            txtout.writelines("-DOCSTART-\n")  # Start each post with a -DOCSTART- token
+            txtout.writelines("-DOCSTART-\n\n")  # Start each post with a -DOCSTART- token

             # Keep track of how many and which entity mentions a given post title has
             link_annots = []  # [start pos, end pos, wiki page title] of an entity mention
@@ -3643,7 +3642,7 @@ def __init__(
             train_file=corpus_file_name,
             column_delimiter="\t",
             in_memory=in_memory,
-            document_separator_token=None if not document_as_sequence else "-DOCSTART-",
+            document_separator_token="-DOCSTART-",
             **corpusargs,
         )

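This change hard-wires "-DOCSTART-" as the document separator instead of making it opt-in via the removed document_as_sequence flag. A minimal sketch of loading such a column-format file with flair's ColumnCorpus (folder and file names are hypothetical):

from flair.datasets import ColumnCorpus

# each post in the generated file starts with a -DOCSTART- line followed by a blank line
columns = {0: "text", 1: "ner"}

corpus = ColumnCorpus(
    "resources/tasks/reddit_el",  # hypothetical data folder
    columns,
    train_file="reddit_el_gold.txt",  # hypothetical file name
    column_delimiter="\t",
    document_separator_token="-DOCSTART-",
)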
