Adding the Reddit Entity Linking Dataset #2148

Merged
merged 1 commit into from Mar 13, 2021
1 change: 1 addition & 0 deletions flair/datasets/__init__.py
@@ -57,6 +57,7 @@
from .sequence_labeling import WSD_UFSAC
from .sequence_labeling import WNUT_2020_NER
from .sequence_labeling import XTREME
from .sequence_labeling import REDDIT_EL_GOLD

# Expose all document classification datasets
from .document_classification import ClassificationCorpus
221 changes: 221 additions & 0 deletions flair/datasets/sequence_labeling.py
@@ -8,6 +8,7 @@
from os import listdir
import zipfile
from zipfile import ZipFile
import csv


import flair
@@ -3518,3 +3519,223 @@ def xtreme_to_simple_ner_annotation(data_file: Union[str, Path]):
else:
liste = line.split()
f.write(liste[0].split(':', 1)[1] + ' ' + liste[1] + '\n')


class REDDIT_EL_GOLD(ColumnCorpus):
def __init__(
self,
base_path: Union[str, Path] = None,
in_memory: bool = True,
document_as_sequence: bool = False,
**corpusargs,
):
"""
Initialize the Reddit Entity Linking corpus containing gold annotations only (https://arxiv.org/abs/2101.01228v2) in the NER-like column format.
The first time you call this constructor it will automatically download the dataset.
:param base_path: Default is None, meaning that corpus gets auto-downloaded and loaded. You can override this
to point to a different folder but typically this should not be necessary.
:param in_memory: If True, keeps dataset in memory giving speedups in training.
:param document_as_sequence: If True, all sentences of a document are read into a single Sentence object
"""
        if isinstance(base_path, str):
base_path: Path = Path(base_path)

# column format
columns = {0: "text", 1: "ner"}

# this dataset name
dataset_name = self.__class__.__name__.lower()

# default dataset folder is the cache root
if not base_path:
base_path = Path(flair.cache_root) / "datasets"
data_folder = base_path / dataset_name

# download and parse data if necessary
reddit_el_path = "https://zenodo.org/record/3970806/files/reddit_el.zip"
corpus_file_name = "reddit_el_gold.txt"
parsed_dataset = data_folder / corpus_file_name

if not parsed_dataset.exists():
reddit_el_zip = cached_path(f"{reddit_el_path}", Path("datasets") / dataset_name)
unpack_file(reddit_el_zip, data_folder, "zip", False)

with open(data_folder / corpus_file_name, "w") as txtout:

# First parse the post titles
with open(data_folder / "posts.tsv", "r") as tsvin1, open(data_folder / "gold_post_annotations.tsv", "r") as tsvin2:

posts = csv.reader(tsvin1, delimiter="\t")
self.post_annotations = csv.reader(tsvin2, delimiter="\t")
self.curr_annot = next(self.post_annotations)

for row in posts: # Go through all the post titles

txtout.writelines("-DOCSTART-\n") # Start each post with a -DOCSTART- token

                        # Keep track of how many entity mentions a given post title has, and which ones
link_annots = [] # [start pos, end pos, wiki page title] of an entity mention

# Check if the current post title has an entity link and parse accordingly
if row[0] == self.curr_annot[0]:

link_annots.append((int(self.curr_annot[4]), int(self.curr_annot[5]), self.curr_annot[3]))
link_annots = self._fill_annot_array(link_annots, row[0], post_flag = True)

# Post titles with entity mentions (if any) are handled via this function
self._text_to_cols(Sentence(row[2], use_tokenizer = True), link_annots, txtout)
else:
self._text_to_cols(Sentence(row[2], use_tokenizer = True), link_annots, txtout)

# Then parse the comments
with open(data_folder / "comments.tsv", "r") as tsvin3, open(data_folder / "gold_comment_annotations.tsv", "r") as tsvin4:

self.comments = csv.reader(tsvin3, delimiter="\t")
self.comment_annotations = csv.reader(tsvin4, delimiter="\t")
self.curr_annot = next(self.comment_annotations)
self.curr_row = next(self.comments)
self.stop_iter = False

# Iterate over the comments.tsv file, until the end is reached
while not self.stop_iter:

txtout.writelines("-DOCSTART-\n") # Start each comment thread with a -DOCSTART- token

# Keep track of the current comment thread and its corresponding key, on which the annotations are matched.
# Each comment thread is handled as one 'document'.
self.curr_comm = self.curr_row[4]
comm_key = self.curr_row[0]

                        # Python's csv module fails to parse a handful of rows inside the comments.tsv file correctly.
                        # This if-condition is needed to handle those rows.
if comm_key in {"en5rf4c", "es3ia8j", "es3lrmw"}:
if comm_key == "en5rf4c":
self.parsed_row = (r.split("\t") for r in self.curr_row[4].split("\n"))
self.curr_comm = next(self.parsed_row)
self._fill_curr_comment(fix_flag = True)
# In case we are dealing with properly parsed rows, proceed with a regular parsing procedure
else:
self._fill_curr_comment(fix_flag = False)

link_annots = [] # [start pos, end pos, wiki page title] of an entity mention

# Check if the current comment thread has an entity link and parse accordingly, same as with post titles above
if comm_key == self.curr_annot[0]:
link_annots.append((int(self.curr_annot[4]), int(self.curr_annot[5]), self.curr_annot[3]))
link_annots = self._fill_annot_array(link_annots, comm_key, post_flag = False)
self._text_to_cols(Sentence(self.curr_comm, use_tokenizer = True), link_annots, txtout)
else:
                            # In two of the comment threads, capital letter spacing occurs (e.g. "W O R D"), which the
                            # SegtokTokenizer cannot handle properly. The following if-elif condition handles these two
                            # cases and, as a result, writes fully capitalized words into each corresponding row instead
                            # of single letters.
if comm_key == "dv74ybb":
self.curr_comm = " ".join([word.replace(" ", "") for word in self.curr_comm.split(" ")])
elif comm_key == "eci2lut":
self.curr_comm = (self.curr_comm[:18] + self.curr_comm[18:27].replace(" ", "") + self.curr_comm[27:55] +
self.curr_comm[55:68].replace(" ", "") + self.curr_comm[68:85] + self.curr_comm[85:92].replace(" ", "") +
self.curr_comm[92:])

self._text_to_cols(Sentence(self.curr_comm, use_tokenizer = True), link_annots, txtout)

super(REDDIT_EL_GOLD, self).__init__(
data_folder,
columns,
train_file=corpus_file_name,
column_delimiter="\t",
in_memory=in_memory,
document_separator_token=None if not document_as_sequence else "-DOCSTART-",
**corpusargs,
)

def _text_to_cols(self, sentence: Sentence, links: list, outfile):
"""
Convert a tokenized sentence into column format
:param sentence: Flair Sentence object containing a tokenized post title or comment thread
:param links: array containing information about the starting and ending position of an entity mention, as well
as its corresponding wiki tag
:param outfile: file, to which the output is written
"""
        for i in range(len(sentence)):
# If there are annotated entity mentions for given post title or a comment thread
if links:
                # Keep track of which entity link is the correct one, in cases where there is more than one link in a sentence
link_index = [j for j,v in enumerate(links) if (sentence[i].start_pos >= v[0] and sentence[i].end_pos <= v[1])]
# Write the token with a corresponding tag to file
try:
                    if any(sentence[i].start_pos == v[0] and sentence[i].end_pos == v[1] for v in links):
                        outfile.writelines(sentence[i].text + "\tS-Link:" + links[link_index[0]][2] + "\n")
                    elif any(sentence[i].start_pos == v[0] and sentence[i].end_pos != v[1] for v in links):
                        outfile.writelines(sentence[i].text + "\tB-Link:" + links[link_index[0]][2] + "\n")
                    elif any(sentence[i].start_pos >= v[0] and sentence[i].end_pos <= v[1] for v in links):
                        outfile.writelines(sentence[i].text + "\tI-Link:" + links[link_index[0]][2] + "\n")
else:
outfile.writelines(sentence[i].text + "\tO\n")
                # An IndexError is raised in cases where there is exactly one link in a sentence, and can therefore be dismissed
except IndexError:
pass

# If a comment thread or a post title has no entity link, all tokens are assigned the O tag
else:
outfile.writelines(sentence[i].text + "\tO\n")

# Prevent writing empty lines if e.g. a quote comes after a dot or initials are tokenized
# incorrectly, in order to keep the desired format (empty line as a sentence separator).
try:
if ((sentence[i].text in {".", "!", "?", "!*"}) and
(sentence[i+1].text not in {'"', '“', "'", "''", "!", "?", ";)", "."}) and
("." not in sentence[i-1].text)):
outfile.writelines("\n")
except IndexError:
                # Thrown when the second check above is performed but the last token of the sentence has been reached.
                # It indicates that the EOS punctuation mark is present, therefore an empty line needs to be written below.
outfile.writelines("\n")

# If there is no punctuation mark indicating EOS, an empty line is still needed after the EOS
if sentence[-1].text not in {".", "!", "?"}:
outfile.writelines("\n")

def _fill_annot_array(self, annot_array: list, key: str, post_flag: bool) -> list:
"""
Fills the array containing information about the entity mention annotations, used in the _text_to_cols method
:param annot_array: array to be filled
:param key: reddit id, on which the post title/comment thread is matched with its corresponding annotation
:param post_flag: flag indicating whether the annotations are collected for the post titles (=True)
or comment threads (=False)
"""
next_annot = None
while True:
# Check if further annotations belong to the current post title or comment thread as well
try:
next_annot = next(self.post_annotations) if post_flag else next(self.comment_annotations)
if next_annot[0] == key:
annot_array.append((int(next_annot[4]), int(next_annot[5]), next_annot[3]))
else:
self.curr_annot = next_annot
break
# Stop when the end of an annotation file is reached
except StopIteration:
break
return annot_array

def _fill_curr_comment(self, fix_flag: bool):
"""
Extends the string containing the current comment thread, which is passed to _text_to_cols method, when the
comments are parsed.
:param fix_flag: flag indicating whether the method is called when the incorrectly imported rows are parsed (=True)
or regular rows (=False)
"""
next_row = None
while True:
            # Check if further rows belong to the current comment thread as well
try:
next_row = next(self.comments) if not fix_flag else next(self.parsed_row)
if len(next_row) < 2:
# 'else " "' is needed to keep the proper token positions (for accordance with annotations)
self.curr_comm += next_row[0] if any(next_row) else " "
else:
self.curr_row = next_row
break
except StopIteration: # When the end of the comments.tsv file is reached
self.curr_row = next_row
                self.stop_iter = not fix_flag
break
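
For reference, a minimal usage sketch of the new loader (the loader name and tag column come from the diff above; `get_tag`/`Label.value` are assumed to be the standard flair accessors for token-level labels):

from flair.datasets import REDDIT_EL_GOLD

# The first call downloads the Zenodo archive and parses it into reddit_el_gold.txt
corpus = REDDIT_EL_GOLD()
print(corpus)

# Inspect the entity-link tags of the first training sentence
sentence = corpus.train[0]
for token in sentence:
    print(token.text, token.get_tag("ner").value)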