Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-2426: multi-label tars #2430

Merged
merged 2 commits into from
Sep 13, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 46 additions & 21 deletions flair/models/tars_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,14 @@ def get_current_label_dictionary(self):
def get_current_label_type(self):
return self._task_specific_attributes[self._current_task]['label_type']

def is_current_task_multi_label(self):
return self._task_specific_attributes[self._current_task]['multi_label']

def add_and_switch_to_new_task(self,
task_name,
label_dictionary: Union[List, Set, Dictionary, str],
label_type: str,
multi_label: bool = True,
force_switch: bool = False,
):
"""
Expand All @@ -179,13 +183,13 @@ def add_and_switch_to_new_task(self,
size and negative sampling. This method does not store the resultant model onto disk.
:param task_name: a string depicting the name of the task
:param label_dictionary: dictionary of the labels you want to predict
:param multi_label: auto-detect if a corpus label dictionary is provided. Defaults to True otherwise
:param multi_label_threshold: If multi-label you can set the threshold to make predictions
:param label_type: string to identify the label type ('ner', 'sentiment', etc.)
:param multi_label: whether this task is a multi-label prediction problem
:param force_switch: if True, will overwrite existing task with same name
"""
if task_name in self._task_specific_attributes and not force_switch:
log.warning("Task `%s` already exists in TARS model. Switching to it.", task_name)
else:

# make label dictionary if no Dictionary object is passed
if isinstance(label_dictionary, Dictionary):
label_dictionary = label_dictionary.get_items()
Expand All @@ -202,7 +206,9 @@ def add_and_switch_to_new_task(self,
else:
tag_dictionary.add_item(tag)

self._task_specific_attributes[task_name] = {'label_dictionary': tag_dictionary, 'label_type': label_type}
self._task_specific_attributes[task_name] = {'label_dictionary': tag_dictionary,
'label_type': label_type,
'multi_label': multi_label}

self.switch_to_task(task_name)

Expand Down Expand Up @@ -261,27 +267,23 @@ def predict_zero_shot(self,
log.warning("Provided candidate_label_set is empty")
return

label_dictionary = Dictionary(add_unk=False)
label_dictionary.multi_label = multi_label

# make list if only one candidate label is passed
if isinstance(candidate_label_set, str):
candidate_label_set = {candidate_label_set}

# if list is passed, convert to set
if not isinstance(candidate_label_set, set):
candidate_label_set = set(candidate_label_set)

# create label dictionary
label_dictionary = Dictionary(add_unk=False)
for label in candidate_label_set:
label_dictionary.add_item(label)

# note current task
existing_current_task = self._current_task

# create a temporary task
self.add_and_switch_to_new_task("ZeroShot",
label_dictionary,
'-'.join(label_dictionary.get_items()))
self.add_and_switch_to_new_task(task_name="ZeroShot",
label_dictionary=label_dictionary,
label_type='-'.join(label_dictionary.get_items()),
multi_label=multi_label)

try:
# make zero shot predictions
Expand Down Expand Up @@ -714,8 +716,6 @@ def _init_model_with_state_dict(state):
# set all task information
model._task_specific_attributes = state["task_specific_attributes"]

print(model._task_specific_attributes)

# linear layers of internal classifier
model.load_state_dict(state["state_dict"])
return model
Expand Down Expand Up @@ -746,6 +746,8 @@ def predict(
label_name: Optional[str] = None,
return_loss=False,
embedding_storage_mode="none",
label_threshold: float = 0.5,
multi_label: Optional[bool] = None,
):
"""
Predict sequence tags for Named Entity Recognition task
Expand All @@ -761,9 +763,12 @@ def predict(
you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory respectively.
'gpu' to store embeddings in GPU memory.
"""
if label_name == None:
if not label_name:
label_name = self.get_current_label_type()

if multi_label is None:
multi_label = self.is_current_task_multi_label()

# with torch.no_grad():
if not sentences:
return sentences
Expand Down Expand Up @@ -815,19 +820,39 @@ def predict(

all_labels = [label.decode("utf-8") for label in self.get_current_label_dictionary().idx2item]

best_label = None
for label in all_labels:
tars_sentence = self._get_tars_formatted_sentence(label, sentence)

loss_and_count = self.tars_model.predict(tars_sentence,
label_name=label_name,
return_loss=True)
return_loss=True,
return_probabilities_for_all_classes=True
if label_threshold < 0.5 else False,
)

overall_loss += loss_and_count[0].item()
overall_count += loss_and_count[1]

predicted_tars_label = tars_sentence.get_labels(label_name)[0]
if predicted_tars_label.value == self.LABEL_MATCH:
sentence.add_label(label_name, label, predicted_tars_label.score)
# add all labels that according to TARS match the text and are above threshold
for predicted_tars_label in tars_sentence.get_labels(label_name):
if predicted_tars_label.value == self.LABEL_MATCH \
and predicted_tars_label.score > label_threshold:
# do not add labels below confidence threshold
sentence.add_label(label_name, label, predicted_tars_label.score)

# only use label with highest confidence if enforcing single-label predictions
if not multi_label:
if len(sentence.get_labels()) > 0:

# get all label scores and do an argmax to get the best label
label_scores = torch.tensor([label.score for label in sentence.get_labels(label_name)],
dtype=torch.float)
best_label = sentence.get_labels(label_name)[torch.argmax(label_scores)]

# remove previously added labels and only add the best label
sentence.remove_labels(label_name)
sentence.add_label(typename=label_name, value=best_label.value, score=best_label.score)

# clearing token embeddings to save memory
store_embeddings(batch, storage_mode=embedding_storage_mode)
Expand Down