Prototypical networks in Flair #2627

Merged: 26 commits, Feb 9, 2022

Commits:
627db34  Refactor decoder to DefaultClassifier (alanakbik, Jan 27, 2022)
75ded15  Slightly changed definition of _init_model_with_state_dict and _get_s… (plonerma, Jan 31, 2022)
caa80f7  Added possibility to use custom decoder in DefaultClassifier (plonerma, Jan 31, 2022)
0342e88  Fixed variable name (plonerma, Feb 1, 2022)
04eb35c  Added prototypical decoder (plonerma, Feb 1, 2022)
a291500  Improved patch for train function (plonerma, Feb 1, 2022)
fd45342  Fixed wrong var name (plonerma, Feb 4, 2022)
937b558  Removed unneded params for super (class not defined) (plonerma, Feb 4, 2022)
aeaca36  Merge branch 'custom_decoder_refactor' into prototype_decoder_refactor (plonerma, Feb 4, 2022)
74a8c2b  Adapted models _get_state_dict and _init_model_with_state_dict method… (plonerma, Feb 4, 2022)
248d19e  Removed wrong char (plonerma, Feb 8, 2022)
1bc0b93  Reformat (alanakbik, Feb 9, 2022)
f586e2a  Merge branch 'master' into prototype_decoder_refactor (alanakbik, Feb 9, 2022)
e4b70b9  Fix memory issues (alanakbik, Feb 9, 2022)
f9efd39  Change logits variable name (alanakbik, Feb 9, 2022)
afc4e31  Merge branch 'custom_decoder_refactor' into prototype_decoder_refactor (plonerma, Feb 9, 2022)
46e53f0  Merge branch 'prototype_decoder_refactor' of github.com:flairNLP/flai… (plonerma, Feb 9, 2022)
0dc57d4  Reformat (alanakbik, Feb 9, 2022)
6f18860  Fix mypy (alanakbik, Feb 9, 2022)
6655b7f  Fix errors in unit tests (alanakbik, Feb 9, 2022)
5139e10  Fix flake8 errors (alanakbik, Feb 9, 2022)
e74e382  Add kwargs back in (alanakbik, Feb 9, 2022)
790b1dd  Black formatting (alanakbik, Feb 9, 2022)
6fdbcea  Fix forward pass for empty batches (alanakbik, Feb 9, 2022)
1f1cd77  Add length check to decoder (alanakbik, Feb 9, 2022)
b462fdf  Add length check to predict (alanakbik, Feb 9, 2022)
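
The thread through these commits: classification models are refactored onto a shared DefaultClassifier base that owns the final decoder layer, and that decoder becomes swappable, which is what makes a prototypical decoder possible. The sketch below illustrates the idea of such a decoder; it is a simplified assumption for illustration, not the class added by this PR (the name PrototypicalDecoderSketch and its parameters are hypothetical):

```python
import torch


class PrototypicalDecoderSketch(torch.nn.Module):
    """Toy prototypical decoder: logits are negated distances to class prototypes."""

    def __init__(self, num_classes: int, embedding_size: int):
        super().__init__()
        # one learnable prototype vector per class label
        self.prototypes = torch.nn.Parameter(torch.randn(num_classes, embedding_size))

    def forward(self, embedded: torch.Tensor) -> torch.Tensor:
        # shape (batch, num_classes): the closer an embedding is to a
        # class prototype, the higher that class's logit
        return -torch.cdist(embedded, self.prototypes)
```

Because the base class builds its default linear decoder from a final_embedding_size argument (visible in the diffs below), any module that maps embeddings to per-class scores could be passed in its place.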
13 changes: 6 additions & 7 deletions flair/models/dependency_parser_model.py
@@ -372,7 +372,7 @@ def _obtain_labels_(
 
     def _get_state_dict(self):
         model_state = {
-            "state_dict": self.state_dict(),
+            **super()._get_state_dict(),
             "token_embeddings": self.token_embeddings,
             "use_rnn": self.use_rnn,
             "lstm_hidden_size": self.lstm_hidden_size,
@@ -385,10 +385,10 @@ def _get_state_dict(self):
         }
         return model_state
 
-    @staticmethod
-    def _init_model_with_state_dict(state):
-
-        model = DependencyParser(
+    @classmethod
+    def _init_model_with_state_dict(cls, state, **kwargs):
+        return super()._init_model_with_state_dict(
+            state,
             token_embeddings=state["token_embeddings"],
             relations_dictionary=state["relations_dictionary"],
             use_rnn=state["use_rnn"],
@@ -398,9 +398,8 @@ def _init_model_with_state_dict(state):
             lstm_layers=state["lstm_layers"],
             mlp_dropout=state["mlp_dropout"],
             lstm_dropout=state["lstm_dropout"],
+            **kwargs,
         )
-        model.load_state_dict(state["state_dict"])
-        return model
 
     @property
     def label_type(self):
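
The same two-sided change repeats in every file below: _get_state_dict spreads the base class's dict instead of serializing the weights itself, and _init_model_with_state_dict becomes a classmethod that delegates construction upward. A plausible sketch of the base-class contract being delegated to, inferred from the diffs rather than copied from Flair's source:

```python
import torch


class ModelBaseSketch(torch.nn.Module):
    """Assumed shape of the flair.nn.Model serialization hooks after this PR."""

    def _get_state_dict(self):
        # the base contributes the raw weights; subclasses spread this dict
        # and add their own constructor arguments on top
        return {"state_dict": self.state_dict()}

    @classmethod
    def _init_model_with_state_dict(cls, state, **kwargs):
        # subclasses forward their saved constructor arguments via kwargs;
        # the base rebuilds the model and restores the weights once
        model = cls(**kwargs)
        model.load_state_dict(state["state_dict"])
        return model
```

This removes the model.load_state_dict(...) / return model boilerplate from every subclass, and it is why @staticmethod becomes @classmethod: the shared implementation needs cls to construct the correct subclass.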
14 changes: 7 additions & 7 deletions flair/models/diagnosis/distance_prediction_model.py
@@ -146,7 +146,7 @@ def forward(self, sentence: Sentence):
 
     def _get_state_dict(self):
         model_state = {
-            "state_dict": self.state_dict(),
+            **super()._get_state_dict(),
             "word_embeddings": self.word_embeddings,
             "max_distance": self.max_distance,
             "beta": self.beta,
@@ -156,23 +156,23 @@ def _get_state_dict(self):
         }
         return model_state
 
-    @staticmethod
-    def _init_model_with_state_dict(state):
+    @classmethod
+    def _init_model_with_state_dict(cls, state, **kwargs):
 
         beta = 1.0 if "beta" not in state.keys() else state["beta"]
         weight = 1 if "loss_max_weight" not in state.keys() else state["loss_max_weight"]
 
-        model = DistancePredictor(
+        return super()._init_model_with_state_dict(
+            state,
             word_embeddings=state["word_embeddings"],
             max_distance=state["max_distance"],
             beta=beta,
             loss_max_weight=weight,
             regression=state["regression"],
             regr_loss_step=state["regr_loss_step"],
+            **kwargs,
         )
 
-        model.load_state_dict(state["state_dict"])
-        return model
-
     # So far only one sentence allowed
     # If list of sentences is handed the function works with the first sentence of the list
     def forward_loss(self, data_points: Union[List[Sentence], Sentence]) -> torch.Tensor:
53 changes: 20 additions & 33 deletions flair/models/entity_linker_model.py
@@ -3,7 +3,6 @@
 from typing import List, Optional, Union
 
 import torch
-import torch.nn as nn
 
 import flair.embeddings
 import flair.nn
@@ -40,7 +39,13 @@ def __init__(
         :param label_type: name of the label you use.
         """
 
-        super(EntityLinker, self).__init__(label_dictionary, **classifierargs)
+        super(EntityLinker, self).__init__(
+            label_dictionary=label_dictionary,
+            final_embedding_size=word_embeddings.embedding_length * 2
+            if pooling_operation == "first&last"
+            else word_embeddings.embedding_length,
+            **classifierargs,
+        )
 
         self.word_embeddings = word_embeddings
         self.pooling_operation = pooling_operation
@@ -55,16 +60,6 @@ def __init__(
         if dropout > 0.0:
             self.dropout = torch.nn.Dropout(dropout)
 
-        # if we concatenate the embeddings we need double input size in our linear layer
-        if self.pooling_operation == "first&last":
-            self.decoder = nn.Linear(2 * self.word_embeddings.embedding_length, len(self.label_dictionary)).to(
-                flair.device
-            )
-        else:
-            self.decoder = nn.Linear(self.word_embeddings.embedding_length, len(self.label_dictionary)).to(flair.device)
-
-        nn.init.xavier_uniform_(self.decoder.weight)
-
         cases = {
             "average": self.emb_mean,
             "first": self.emb_first,
@@ -110,13 +105,10 @@ def forward_pass(
         span_labels = []
         sentences_to_spans = []
         empty_label_candidates = []
+        embedded_entity_pairs = None
 
-        # if the entire batch has no sentence with candidates, return empty
-        if len(filtered_sentences) == 0:
-            scores = None
-
-        # otherwise, embed sentence and send through prediction head
-        else:
+        # embed sentences and send through prediction head
+        if len(filtered_sentences) > 0:
             # embed all tokens
             self.word_embeddings.embed(filtered_sentences)
 
@@ -152,23 +144,19 @@
                     empty_label_candidates.append(candidate)
 
             if len(embedding_list) > 0:
-                embedding_tensor = torch.cat(embedding_list, 0).to(flair.device)
+                embedded_entity_pairs = torch.cat(embedding_list, 0)
 
                 if self.use_dropout:
-                    embedding_tensor = self.dropout(embedding_tensor)
-
-                scores = self.decoder(embedding_tensor)
-            else:
-                scores = None
+                    embedded_entity_pairs = self.dropout(embedded_entity_pairs)
 
         if return_label_candidates:
-            return scores, span_labels, sentences_to_spans, empty_label_candidates
+            return embedded_entity_pairs, span_labels, sentences_to_spans, empty_label_candidates
 
-        return scores, span_labels
+        return embedded_entity_pairs, span_labels
 
     def _get_state_dict(self):
         model_state = {
-            "state_dict": self.state_dict(),
+            **super()._get_state_dict(),
             "word_embeddings": self.word_embeddings,
             "label_type": self.label_type,
             "label_dictionary": self.label_dictionary,
@@ -177,19 +165,18 @@ def _get_state_dict(self):
         }
         return model_state
 
-    @staticmethod
-    def _init_model_with_state_dict(state):
-        model = EntityLinker(
+    @classmethod
+    def _init_model_with_state_dict(cls, state, **kwargs):
+        return super()._init_model_with_state_dict(
+            state,
             word_embeddings=state["word_embeddings"],
             label_dictionary=state["label_dictionary"],
             label_type=state["label_type"],
             pooling_operation=state["pooling_operation"],
             loss_weights=state["loss_weights"] if "loss_weights" in state else {"<unk>": 0.3},
+            **kwargs,
         )
-
-        model.load_state_dict(state["state_dict"])
-        return model
 
     @property
     def label_type(self):
         return self._label_type
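
The EntityLinker diff also shows the other half of the refactor: forward_pass no longer applies a decoder itself but returns the pooled embeddings (or None for an empty batch), and the base class turns them into logits. A minimal sketch of that caller side, with hypothetical names; not verbatim Flair code:

```python
import torch


class DefaultClassifierSketch(torch.nn.Module):
    """Assumed shape of the base class that consumes forward_pass embeddings."""

    def __init__(self, final_embedding_size: int, num_labels: int, decoder=None):
        super().__init__()
        # the base class owns the decoder now; callers may supply a custom
        # one (e.g. a prototypical decoder) instead of the default Linear
        self.decoder = decoder or torch.nn.Linear(final_embedding_size, num_labels)
        self.loss_function = torch.nn.CrossEntropyLoss()

    def _calculate_loss(self, embedded, targets: torch.Tensor) -> torch.Tensor:
        # guard against empty batches, as the "length check" commits do
        if embedded is None or len(embedded) == 0:
            return torch.tensor(0.0, requires_grad=True)
        return self.loss_function(self.decoder(embedded), targets)
```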
11 changes: 6 additions & 5 deletions flair/models/lemmatizer_model.py
@@ -641,7 +641,7 @@ def predict(
 
     def _get_state_dict(self):
         model_state = {
-            "state_dict": self.state_dict(),
+            **super()._get_state_dict(),
             "embeddings": self.encoder_embeddings,
             "rnn_input_size": self.rnn_input_size,
             "rnn_hidden_size": self.rnn_hidden_size,
@@ -660,8 +660,10 @@ def _get_state_dict(self):
 
         return model_state
 
-    def _init_model_with_state_dict(state):
-        model = Lemmatizer(
+    @classmethod
+    def _init_model_with_state_dict(cls, state, **kwargs):
+        return super()._init_model_with_state_dict(
+            state,
             embeddings=state["embeddings"],
             encode_characters=state["encode_characters"],
             rnn_input_size=state["rnn_input_size"],
@@ -676,9 +678,8 @@ def _init_model_with_state_dict(state):
             start_symbol_for_encoding=state["start_symbol"],
             end_symbol_for_encoding=state["end_symbol"],
             bidirectional_encoding=state["bidirectional_encoding"],
+            **kwargs,
         )
-        model.load_state_dict(state["state_dict"])
-        return model
 
     def _print_predictions(self, batch, gold_label_type):
         lines = []
37 changes: 13 additions & 24 deletions flair/models/pairwise_classification_model.py
@@ -33,28 +33,20 @@ def __init__(
         :param loss_weights: Dictionary of weights for labels for the loss function
                              (if any label's weight is unspecified it will default to 1.0)
         """
-        super().__init__(**classifierargs)
+        super().__init__(
+            **classifierargs,
+            final_embedding_size=2 * document_embeddings.embedding_length
+            if embed_separately
+            else document_embeddings.embedding_length,
+        )
 
         self.document_embeddings: flair.embeddings.DocumentEmbeddings = document_embeddings
 
         self._label_type = label_type
 
         self.embed_separately = embed_separately
 
-        # if embed_separately == True the linear layer needs twice the length of the embeddings as input size
-        # since we concatenate the embeddings of the two DataPoints in the DataPairs
-        if self.embed_separately:
-            self.decoder = torch.nn.Linear(
-                2 * self.document_embeddings.embedding_length,
-                len(self.label_dictionary),
-            ).to(flair.device)
-
-            torch.nn.init.xavier_uniform_(self.decoder.weight)
-
-        else:
-            # representation for both sentences
-            self.decoder = torch.nn.Linear(self.document_embeddings.embedding_length, len(self.label_dictionary))
-
+        if not self.embed_separately:
             # set separator to concatenate two sentences
             self.sep = " "
             if isinstance(
@@ -66,8 +58,6 @@ def __init__(
             else:
                 self.sep = " [SEP] "
 
-        torch.nn.init.xavier_uniform_(self.decoder.weight)
-
         # auto-spawn on GPU if available
         self.to(flair.device)
 
@@ -136,7 +126,7 @@ def forward_pass(
 
     def _get_state_dict(self):
         model_state = {
-            "state_dict": self.state_dict(),
+            **super()._get_state_dict(),
             "document_embeddings": self.document_embeddings,
             "label_dictionary": self.label_dictionary,
             "label_type": self.label_type,
@@ -147,10 +137,10 @@ def _get_state_dict(self):
         }
         return model_state
 
-    @staticmethod
-    def _init_model_with_state_dict(state):
-
-        model = TextPairClassifier(
+    @classmethod
+    def _init_model_with_state_dict(cls, state, **kwargs):
+        return super()._init_model_with_state_dict(
+            state,
             document_embeddings=state["document_embeddings"],
             label_dictionary=state["label_dictionary"],
             label_type=state["label_type"],
@@ -160,6 +150,5 @@ def _init_model_with_state_dict(state):
             else state["multi_label_threshold"],
             loss_weights=state["weight_dict"],
             embed_separately=state["embed_separately"],
+            **kwargs,
         )
-        model.load_state_dict(state["state_dict"])
-        return model
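
With those hooks in place, saving and loading any of these models runs through the shared code path. A usage sketch, assuming a trained classifier instance (Model.save and Model.load are Flair's standard entry points; the file name is illustrative):

```python
from flair.models import TextPairClassifier

# save() serializes the dict built by _get_state_dict(); load() routes
# through _init_model_with_state_dict() on the way back
classifier.save("pair-classifier.pt")
loaded = TextPairClassifier.load("pair-classifier.pt")
```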