diff --git a/flair/models/dependency_parser_model.py b/flair/models/dependency_parser_model.py index 4754e28790..4ddfa74120 100644 --- a/flair/models/dependency_parser_model.py +++ b/flair/models/dependency_parser_model.py @@ -372,7 +372,7 @@ def _obtain_labels_( def _get_state_dict(self): model_state = { - "state_dict": self.state_dict(), + **super()._get_state_dict(), "token_embeddings": self.token_embeddings, "use_rnn": self.use_rnn, "lstm_hidden_size": self.lstm_hidden_size, @@ -385,10 +385,10 @@ def _get_state_dict(self): } return model_state - @staticmethod - def _init_model_with_state_dict(state): - - model = DependencyParser( + @classmethod + def _init_model_with_state_dict(cls, state, **kwargs): + return super()._init_model_with_state_dict( + state, token_embeddings=state["token_embeddings"], relations_dictionary=state["relations_dictionary"], use_rnn=state["use_rnn"], @@ -398,9 +398,8 @@ def _init_model_with_state_dict(state): lstm_layers=state["lstm_layers"], mlp_dropout=state["mlp_dropout"], lstm_dropout=state["lstm_dropout"], + **kwargs, ) - model.load_state_dict(state["state_dict"]) - return model @property def label_type(self): diff --git a/flair/models/diagnosis/distance_prediction_model.py b/flair/models/diagnosis/distance_prediction_model.py index 2c74aee85c..e98de954ff 100644 --- a/flair/models/diagnosis/distance_prediction_model.py +++ b/flair/models/diagnosis/distance_prediction_model.py @@ -146,7 +146,7 @@ def forward(self, sentence: Sentence): def _get_state_dict(self): model_state = { - "state_dict": self.state_dict(), + **super()._get_state_dict(), "word_embeddings": self.word_embeddings, "max_distance": self.max_distance, "beta": self.beta, @@ -156,23 +156,23 @@ def _get_state_dict(self): } return model_state - @staticmethod - def _init_model_with_state_dict(state): + @classmethod + def _init_model_with_state_dict(cls, state, **kwargs): + beta = 1.0 if "beta" not in state.keys() else state["beta"] weight = 1 if "loss_max_weight" not in state.keys() else state["loss_max_weight"] - model = DistancePredictor( + return super()._init_model_with_state_dict( + state, word_embeddings=state["word_embeddings"], max_distance=state["max_distance"], beta=beta, loss_max_weight=weight, regression=state["regression"], regr_loss_step=state["regr_loss_step"], + **kwargs, ) - model.load_state_dict(state["state_dict"]) - return model - # So far only one sentence allowed # If list of sentences is handed the function works with the first sentence of the list def forward_loss(self, data_points: Union[List[Sentence], Sentence]) -> torch.Tensor: diff --git a/flair/models/entity_linker_model.py b/flair/models/entity_linker_model.py index da1023ebfd..2e2fc8b226 100644 --- a/flair/models/entity_linker_model.py +++ b/flair/models/entity_linker_model.py @@ -3,7 +3,6 @@ from typing import List, Optional, Union import torch -import torch.nn as nn import flair.embeddings import flair.nn @@ -40,7 +39,13 @@ def __init__( :param label_type: name of the label you use. 
""" - super(EntityLinker, self).__init__(label_dictionary, **classifierargs) + super(EntityLinker, self).__init__( + label_dictionary=label_dictionary, + final_embedding_size=word_embeddings.embedding_length * 2 + if pooling_operation == "first&last" + else word_embeddings.embedding_length, + **classifierargs, + ) self.word_embeddings = word_embeddings self.pooling_operation = pooling_operation @@ -55,16 +60,6 @@ def __init__( if dropout > 0.0: self.dropout = torch.nn.Dropout(dropout) - # if we concatenate the embeddings we need double input size in our linear layer - if self.pooling_operation == "first&last": - self.decoder = nn.Linear(2 * self.word_embeddings.embedding_length, len(self.label_dictionary)).to( - flair.device - ) - else: - self.decoder = nn.Linear(self.word_embeddings.embedding_length, len(self.label_dictionary)).to(flair.device) - - nn.init.xavier_uniform_(self.decoder.weight) - cases = { "average": self.emb_mean, "first": self.emb_first, @@ -110,13 +105,10 @@ def forward_pass( span_labels = [] sentences_to_spans = [] empty_label_candidates = [] + embedded_entity_pairs = None - # if the entire batch has no sentence with candidates, return empty - if len(filtered_sentences) == 0: - scores = None - - # otherwise, embed sentence and send through prediction head - else: + # embed sentences and send through prediction head + if len(filtered_sentences) > 0: # embed all tokens self.word_embeddings.embed(filtered_sentences) @@ -152,23 +144,19 @@ def forward_pass( empty_label_candidates.append(candidate) if len(embedding_list) > 0: - embedding_tensor = torch.cat(embedding_list, 0).to(flair.device) + embedded_entity_pairs = torch.cat(embedding_list, 0) if self.use_dropout: - embedding_tensor = self.dropout(embedding_tensor) - - scores = self.decoder(embedding_tensor) - else: - scores = None + embedded_entity_pairs = self.dropout(embedded_entity_pairs) if return_label_candidates: - return scores, span_labels, sentences_to_spans, empty_label_candidates + return embedded_entity_pairs, span_labels, sentences_to_spans, empty_label_candidates - return scores, span_labels + return embedded_entity_pairs, span_labels def _get_state_dict(self): model_state = { - "state_dict": self.state_dict(), + **super()._get_state_dict(), "word_embeddings": self.word_embeddings, "label_type": self.label_type, "label_dictionary": self.label_dictionary, @@ -177,19 +165,18 @@ def _get_state_dict(self): } return model_state - @staticmethod - def _init_model_with_state_dict(state): - model = EntityLinker( + @classmethod + def _init_model_with_state_dict(cls, state, **kwargs): + return super()._init_model_with_state_dict( + state, word_embeddings=state["word_embeddings"], label_dictionary=state["label_dictionary"], label_type=state["label_type"], pooling_operation=state["pooling_operation"], loss_weights=state["loss_weights"] if "loss_weights" in state else {"": 0.3}, + **kwargs, ) - model.load_state_dict(state["state_dict"]) - return model - @property def label_type(self): return self._label_type diff --git a/flair/models/lemmatizer_model.py b/flair/models/lemmatizer_model.py index 1ce2ecd039..e375ab3930 100644 --- a/flair/models/lemmatizer_model.py +++ b/flair/models/lemmatizer_model.py @@ -641,7 +641,7 @@ def predict( def _get_state_dict(self): model_state = { - "state_dict": self.state_dict(), + **super()._get_state_dict(), "embeddings": self.encoder_embeddings, "rnn_input_size": self.rnn_input_size, "rnn_hidden_size": self.rnn_hidden_size, @@ -660,8 +660,10 @@ def _get_state_dict(self): return model_state - 
def _init_model_with_state_dict(state): - model = Lemmatizer( + @classmethod + def _init_model_with_state_dict(cls, state, **kwargs): + return super()._init_model_with_state_dict( + state, embeddings=state["embeddings"], encode_characters=state["encode_characters"], rnn_input_size=state["rnn_input_size"], @@ -676,9 +678,8 @@ def _init_model_with_state_dict(state): start_symbol_for_encoding=state["start_symbol"], end_symbol_for_encoding=state["end_symbol"], bidirectional_encoding=state["bidirectional_encoding"], + **kwargs, ) - model.load_state_dict(state["state_dict"]) - return model def _print_predictions(self, batch, gold_label_type): lines = [] diff --git a/flair/models/pairwise_classification_model.py b/flair/models/pairwise_classification_model.py index cfe54b81ad..e51f3cc32d 100644 --- a/flair/models/pairwise_classification_model.py +++ b/flair/models/pairwise_classification_model.py @@ -33,7 +33,12 @@ def __init__( :param loss_weights: Dictionary of weights for labels for the loss function (if any label's weight is unspecified it will default to 1.0) """ - super().__init__(**classifierargs) + super().__init__( + **classifierargs, + final_embedding_size=2 * document_embeddings.embedding_length + if embed_separately + else document_embeddings.embedding_length, + ) self.document_embeddings: flair.embeddings.DocumentEmbeddings = document_embeddings @@ -41,20 +46,7 @@ def __init__( self.embed_separately = embed_separately - # if embed_separately == True the linear layer needs twice the length of the embeddings as input size - # since we concatenate the embeddings of the two DataPoints in the DataPairs - if self.embed_separately: - self.decoder = torch.nn.Linear( - 2 * self.document_embeddings.embedding_length, - len(self.label_dictionary), - ).to(flair.device) - - torch.nn.init.xavier_uniform_(self.decoder.weight) - - else: - # representation for both sentences - self.decoder = torch.nn.Linear(self.document_embeddings.embedding_length, len(self.label_dictionary)) - + if not self.embed_separately: # set separator to concatenate two sentences self.sep = " " if isinstance( @@ -66,8 +58,6 @@ def __init__( else: self.sep = " [SEP] " - torch.nn.init.xavier_uniform_(self.decoder.weight) - # auto-spawn on GPU if available self.to(flair.device) @@ -136,7 +126,7 @@ def forward_pass( def _get_state_dict(self): model_state = { - "state_dict": self.state_dict(), + **super()._get_state_dict(), "document_embeddings": self.document_embeddings, "label_dictionary": self.label_dictionary, "label_type": self.label_type, @@ -147,10 +137,10 @@ def _get_state_dict(self): } return model_state - @staticmethod - def _init_model_with_state_dict(state): - - model = TextPairClassifier( + @classmethod + def _init_model_with_state_dict(cls, state, **kwargs): + return super()._init_model_with_state_dict( + state, document_embeddings=state["document_embeddings"], label_dictionary=state["label_dictionary"], label_type=state["label_type"], @@ -160,6 +150,5 @@ def _init_model_with_state_dict(state): else state["multi_label_threshold"], loss_weights=state["weight_dict"], embed_separately=state["embed_separately"], + **kwargs, ) - model.load_state_dict(state["state_dict"]) - return model diff --git a/flair/models/relation_extractor_model.py b/flair/models/relation_extractor_model.py index 8a5e65aca4..76714be694 100644 --- a/flair/models/relation_extractor_model.py +++ b/flair/models/relation_extractor_model.py @@ -3,7 +3,6 @@ from typing import List, Optional, Set, Tuple, Union import torch -import torch.nn as nn import 
flair.embeddings import flair.nn @@ -25,7 +24,6 @@ def __init__( dropout_value: float = 0.0, locked_dropout_value: float = 0.1, word_dropout_value: float = 0.0, - non_linear_decoder: Optional[int] = 2048, **classifierargs, ): """ @@ -36,7 +34,14 @@ def __init__( :param loss_weights: Dictionary of weights for labels for the loss function (if any label's weight is unspecified it will default to 1.0) """ - super(RelationExtractor, self).__init__(**classifierargs) + + # pooling operation to get embeddings for entites + self.pooling_operation = pooling_operation + relation_representation_length = 2 * embeddings.embedding_length + if self.pooling_operation == "first_last": + relation_representation_length *= 2 + + super(RelationExtractor, self).__init__(**classifierargs, final_embedding_size=relation_representation_length) # set embeddings self.embeddings: flair.embeddings.TokenEmbeddings = embeddings @@ -60,29 +65,6 @@ def __init__( self.word_dropout_value = word_dropout_value self.word_dropout = flair.nn.WordDropout(word_dropout_value) - # pooling operation to get embeddings for entites - self.pooling_operation = pooling_operation - relation_representation_length = 2 * embeddings.embedding_length - if self.pooling_operation == "first_last": - relation_representation_length *= 2 - if type(self.embeddings) == flair.embeddings.TransformerDocumentEmbeddings: - relation_representation_length = embeddings.embedding_length - - # entity pairs could also be no relation at all, add default value for this case to dictionary - self.label_dictionary.add_item("O") - - # decoder can be linear or nonlinear - self.non_linear_decoder = non_linear_decoder - if non_linear_decoder is not None: - self.decoder_1 = nn.Linear(relation_representation_length, non_linear_decoder) - self.nonlinearity = torch.nn.ReLU() - self.decoder_2 = nn.Linear(non_linear_decoder, len(self.label_dictionary)) - nn.init.xavier_uniform_(self.decoder_1.weight) - nn.init.xavier_uniform_(self.decoder_2.weight) - else: - self.decoder = nn.Linear(relation_representation_length, len(self.label_dictionary)) - nn.init.xavier_uniform_(self.decoder.weight) - self.to(flair.device) def add_entity_markers(self, sentence, span_1, span_2): @@ -215,46 +197,35 @@ def forward_pass( ] ) else: - embedding = torch.cat( - [ - span_1.tokens[0].get_embedding(), - span_2.tokens[0].get_embedding(), - ] - ) + embedding = torch.cat([span_1.tokens[0].get_embedding(), span_2.tokens[0].get_embedding()]) relation_embeddings.append(embedding) # stack and drop out (squeeze and unsqueeze) - all_relations = torch.stack(relation_embeddings).unsqueeze(1) - - all_relations = self.dropout(all_relations) - all_relations = self.locked_dropout(all_relations) - all_relations = self.word_dropout(all_relations) + embedded_entity_pairs = torch.stack(relation_embeddings).unsqueeze(1) - all_relations = all_relations.squeeze(1) + embedded_entity_pairs = self.dropout(embedded_entity_pairs) + embedded_entity_pairs = self.locked_dropout(embedded_entity_pairs) + embedded_entity_pairs = self.word_dropout(embedded_entity_pairs) - # send through decoder - if self.non_linear_decoder: - sentence_relation_scores = self.decoder_2(self.nonlinearity(self.decoder_1(all_relations))) - else: - sentence_relation_scores = self.decoder(all_relations) + embedded_entity_pairs = embedded_entity_pairs.squeeze(1) else: - sentence_relation_scores = None + embedded_entity_pairs = None if return_label_candidates: return ( - sentence_relation_scores, + embedded_entity_pairs, labels, sentences_to_label, 
empty_label_candidates, ) - return sentence_relation_scores, labels + return embedded_entity_pairs, labels def _get_state_dict(self): model_state = { - "state_dict": self.state_dict(), + **super()._get_state_dict(), "embeddings": self.embeddings, "label_dictionary": self.label_dictionary, "label_type": self.label_type, @@ -265,13 +236,14 @@ def _get_state_dict(self): "locked_dropout_value": self.locked_dropout_value, "word_dropout_value": self.word_dropout_value, "entity_pair_filters": self.entity_pair_filters, - "non_linear_decoder": self.non_linear_decoder, } return model_state - @staticmethod - def _init_model_with_state_dict(state): - model = RelationExtractor( + @classmethod + def _init_model_with_state_dict(cls, state, **kwargs): + + return super()._init_model_with_state_dict( + state, embeddings=state["embeddings"], label_dictionary=state["label_dictionary"], label_type=state["label_type"], @@ -282,10 +254,8 @@ def _init_model_with_state_dict(state): locked_dropout_value=state["locked_dropout_value"], word_dropout_value=state["word_dropout_value"], entity_pair_filters=state["entity_pair_filters"], - non_linear_decoder=state["non_linear_decoder"], + **kwargs, ) - model.load_state_dict(state["state_dict"]) - return model @property def label_type(self): diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py index a182c46a78..ab5b15b41c 100644 --- a/flair/models/sequence_tagger_model.py +++ b/flair/models/sequence_tagger_model.py @@ -582,7 +582,7 @@ def _all_scores_for_token(self, scores: torch.Tensor, lengths: List[int]): def _get_state_dict(self): """Returns the state dictionary for this model.""" model_state = { - "state_dict": self.state_dict(), + **super()._get_state_dict(), "embeddings": self.embeddings, "hidden_size": self.hidden_size, "tag_dictionary": self.label_dictionary, @@ -600,8 +600,9 @@ def _get_state_dict(self): return model_state - @staticmethod - def _init_model_with_state_dict(state): + @classmethod + def _init_model_with_state_dict(cls, state, **kwargs): + """Initialize the model from a state dictionary.""" rnn_type = "LSTM" if "rnn_type" not in state.keys() else state["rnn_type"] use_dropout = 0.0 if "use_dropout" not in state.keys() else state["use_dropout"] @@ -615,7 +616,8 @@ def _init_model_with_state_dict(state): state["state_dict"]["crf.transitions"] = state["state_dict"]["transitions"] del state["state_dict"]["transitions"] - model = SequenceTagger( + return super()._init_model_with_state_dict( + state, embeddings=state["embeddings"], tag_dictionary=state["tag_dictionary"], tag_type=state["tag_type"], @@ -630,11 +632,9 @@ def _init_model_with_state_dict(state): reproject_embeddings=reproject_embeddings, loss_weights=weights, init_from_state_dict=True, + **kwargs, ) - model.load_state_dict(state["state_dict"]) - return model - @staticmethod def _fetch_model(model_name) -> str: diff --git a/flair/models/similarity_learning_model.py b/flair/models/similarity_learning_model.py index 0ebd277948..3d9be5fd7c 100644 --- a/flair/models/similarity_learning_model.py +++ b/flair/models/similarity_learning_model.py @@ -332,7 +332,7 @@ def evaluate( def _get_state_dict(self): model_state = { - "state_dict": self.state_dict(), + **super()._get_state_dict(), "input_modality_0_embedding": self.source_embeddings, "input_modality_1_embedding": self.target_embeddings, "similarity_measure": self.similarity_measure, @@ -345,13 +345,16 @@ def _get_state_dict(self): } return model_state - @staticmethod - def _init_model_with_state_dict(state): + 
@classmethod + def _init_model_with_state_dict(cls, state, **kwargs): + # The conversion from old model's constructor interface if "input_embeddings" in state: state["input_modality_0_embedding"] = state["input_embeddings"][0] state["input_modality_1_embedding"] = state["input_embeddings"][1] - model = SimilarityLearner( + + return super()._init_model_with_state_dict( + state, source_embeddings=state["input_modality_0_embedding"], target_embeddings=state["input_modality_1_embedding"], source_mapping=state["source_mapping"], @@ -361,7 +364,5 @@ def _init_model_with_state_dict(state): eval_device=state["eval_device"], recall_at_points=state["recall_at_points"], recall_at_points_weights=state["recall_at_points_weights"], + **kwargs, ) - - model.load_state_dict(state["state_dict"]) - return model diff --git a/flair/models/tars_model.py b/flair/models/tars_model.py index 790bdc3ea3..95d6a6afea 100644 --- a/flair/models/tars_model.py +++ b/flair/models/tars_model.py @@ -410,7 +410,7 @@ def _get_tars_formatted_sentence(self, label, sentence): def _get_state_dict(self): model_state = { - "state_dict": self.state_dict(), + **super()._get_state_dict(), "current_task": self._current_task, "tag_type": self.get_current_label_type(), "tag_dictionary": self.get_current_label_dictionary(), @@ -433,23 +433,22 @@ def _fetch_model(model_name) -> str: return model_name - @staticmethod - def _init_model_with_state_dict(state): - + @classmethod + def _init_model_with_state_dict(cls, state, **kwargs): # init new TARS classifier - model = TARSTagger( + model = super()._init_model_with_state_dict( + state, task_name=state["current_task"], label_dictionary=state["tag_dictionary"], label_type=state["tag_type"], embeddings=state["tars_model"].embeddings, num_negative_labels_to_sample=state["num_negative_labels_to_sample"], prefix=state["prefix"], + **kwargs, ) # set all task information model._task_specific_attributes = state["task_specific_attributes"] - # linear layers of internal classifier - model.load_state_dict(state["state_dict"]) return model @property @@ -685,7 +684,7 @@ def _get_tars_formatted_sentence(self, label, sentence): def _get_state_dict(self): model_state = { - "state_dict": self.state_dict(), + **super()._get_state_dict(), "current_task": self._current_task, "label_type": self.get_current_label_type(), "label_dictionary": self.get_current_label_dictionary(), @@ -695,26 +694,25 @@ def _get_state_dict(self): } return model_state - @staticmethod - def _init_model_with_state_dict(state): - + @classmethod + def _init_model_with_state_dict(cls, state, **kwargs): # init new TARS classifier label_dictionary = state["label_dictionary"] label_type = "default_label" if not state["label_type"] else state["label_type"] - model: TARSClassifier = TARSClassifier( + model: TARSClassifier = super()._init_model_with_state_dict( + state, task_name=state["current_task"], label_dictionary=label_dictionary, label_type=label_type, embeddings=state["tars_model"].document_embeddings, num_negative_labels_to_sample=state["num_negative_labels_to_sample"], + **kwargs, ) # set all task information model._task_specific_attributes = state["task_specific_attributes"] - # linear layers of internal classifier - model.load_state_dict(state["state_dict"]) return model @staticmethod diff --git a/flair/models/text_classification_model.py b/flair/models/text_classification_model.py index 57e7ed2ebf..daa910da4c 100644 --- a/flair/models/text_classification_model.py +++ b/flair/models/text_classification_model.py @@ -3,7 +3,6 @@ from 
typing import List, Tuple, Union import torch -import torch.nn as nn import flair.embeddings import flair.nn @@ -40,15 +39,14 @@ def __init__( (if any label's weight is unspecified it will default to 1.0) """ - super(TextClassifier, self).__init__(**classifierargs) + super(TextClassifier, self).__init__( + **classifierargs, final_embedding_size=document_embeddings.embedding_length + ) self.document_embeddings: flair.embeddings.DocumentEmbeddings = document_embeddings self._label_type = label_type - self.decoder = nn.Linear(self.document_embeddings.embedding_length, len(self.label_dictionary)) - nn.init.xavier_uniform_(self.decoder.weight) - # auto-spawn on GPU if available self.to(flair.device) @@ -71,22 +69,19 @@ def forward_pass( text_embedding_list = [sentence.get_embedding(embedding_names).unsqueeze(0) for sentence in sentences] text_embedding_tensor = torch.cat(text_embedding_list, 0).to(flair.device) - # send through decoder to get logits - scores = self.decoder(text_embedding_tensor) - labels = [] for sentence in sentences: labels.append([label.value for label in sentence.get_labels(self.label_type)]) if return_label_candidates: label_candidates = [Label(value="") for sentence in sentences] - return scores, labels, sentences, label_candidates + return text_embedding_tensor, labels, sentences, label_candidates - return scores, labels + return text_embedding_tensor, labels def _get_state_dict(self): model_state = { - "state_dict": self.state_dict(), + **super()._get_state_dict(), "document_embeddings": self.document_embeddings, "label_dictionary": self.label_dictionary, "label_type": self.label_type, @@ -96,12 +91,14 @@ def _get_state_dict(self): } return model_state - @staticmethod - def _init_model_with_state_dict(state): + @classmethod + def _init_model_with_state_dict(cls, state, **kwargs): + weights = None if "weight_dict" not in state.keys() else state["weight_dict"] label_type = None if "label_type" not in state.keys() else state["label_type"] - model = TextClassifier( + return super()._init_model_with_state_dict( + state, document_embeddings=state["document_embeddings"], label_dictionary=state["label_dictionary"], label_type=label_type, @@ -110,9 +107,8 @@ def _init_model_with_state_dict(state): if "multi_label_threshold" not in state.keys() else state["multi_label_threshold"], loss_weights=weights, + **kwargs, ) - model.load_state_dict(state["state_dict"]) - return model @staticmethod def _fetch_model(model_name) -> str: diff --git a/flair/models/text_regression_model.py b/flair/models/text_regression_model.py index 53fbdb37ac..84c1ef1225 100644 --- a/flair/models/text_regression_model.py +++ b/flair/models/text_regression_model.py @@ -219,21 +219,19 @@ def evaluate( def _get_state_dict(self): model_state = { - "state_dict": self.state_dict(), + **super()._get_state_dict(), "document_embeddings": self.document_embeddings, "label_name": self.label_type, } return model_state - @staticmethod - def _init_model_with_state_dict(state): - + @classmethod + def _init_model_with_state_dict(cls, state, **kwargs): label_name = state["label_name"] if "label_name" in state.keys() else None - model = TextRegressor(document_embeddings=state["document_embeddings"], label_name=label_name) - - model.load_state_dict(state["state_dict"]) - return model + return super()._init_model_with_state_dict( + state, document_embeddings=state["document_embeddings"], label_name=label_name, **kwargs + ) @staticmethod def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]: diff --git 
a/flair/models/word_tagger_model.py b/flair/models/word_tagger_model.py index 9b8ccb74f1..996622e400 100644 --- a/flair/models/word_tagger_model.py +++ b/flair/models/word_tagger_model.py @@ -30,39 +30,37 @@ def __init__( :param tag_type: string identifier for tag type :param beta: Parameter for F-beta score for evaluation and training annealing """ - super().__init__(label_dictionary=tag_dictionary, **classifierargs) + super().__init__( + label_dictionary=tag_dictionary, final_embedding_size=embeddings.embedding_length, **classifierargs + ) # embeddings self.embeddings = embeddings # dictionaries self.tag_type: str = tag_type - self.tagset_size: int = len(tag_dictionary) - - # linear layer - self.linear = torch.nn.Linear(self.embeddings.embedding_length, len(tag_dictionary)) # all parameters will be pushed internally to the specified device self.to(flair.device) def _get_state_dict(self): model_state = { - "state_dict": self.state_dict(), + **super()._get_state_dict(), "embeddings": self.embeddings, "tag_dictionary": self.label_dictionary, "tag_type": self.tag_type, } return model_state - @staticmethod - def _init_model_with_state_dict(state): - model = WordTagger( + @classmethod + def _init_model_with_state_dict(cls, state, **kwargs): + return super()._init_model_with_state_dict( + state, embeddings=state["embeddings"], tag_dictionary=state["tag_dictionary"], tag_type=state["tag_type"], + **kwargs, ) - model.load_state_dict(state["state_dict"]) - return model def forward_pass( self, @@ -81,17 +79,15 @@ def forward_pass( all_embeddings = [token.get_embedding(names) for token in all_tokens] - embedding_tensor = torch.stack(all_embeddings) - - scores = self.linear(embedding_tensor) + embedded_tokens = torch.stack(all_embeddings) labels = [[token.get_tag(self.label_type).value] for token in all_tokens] if return_label_candidates: empty_label_candidates = [Label(value=None, score=0.0) for token in all_tokens] - return scores, labels, all_tokens, empty_label_candidates + return embedded_tokens, labels, all_tokens, empty_label_candidates - return scores, labels + return embedded_tokens, labels @property def label_type(self): diff --git a/flair/nn/__init__.py b/flair/nn/__init__.py index 2645611f06..0872a3bafe 100644 --- a/flair/nn/__init__.py +++ b/flair/nn/__init__.py @@ -1,4 +1,5 @@ +from .decoder import PrototypicalDecoder from .dropout import LockedDropout, WordDropout from .model import Classifier, DefaultClassifier, Model -__all__ = ["LockedDropout", "WordDropout", "Classifier", "DefaultClassifier", "Model"] +__all__ = ["LockedDropout", "WordDropout", "Classifier", "DefaultClassifier", "Model", "PrototypicalDecoder"] diff --git a/flair/nn/decoder.py b/flair/nn/decoder.py new file mode 100644 index 0000000000..7bcfc035ea --- /dev/null +++ b/flair/nn/decoder.py @@ -0,0 +1,214 @@ +import logging +from collections import Counter +from typing import List, Optional + +import torch +from tqdm import tqdm + +import flair +from flair.data import FlairDataset +from flair.datasets import DataLoader +from flair.nn.distance import ( + CosineDistance, + EuclideanDistance, + HyperbolicDistance, + LogitCosineDistance, + NegativeScaledDotProduct, +) +from flair.nn.model import DefaultClassifier +from flair.training_utils import store_embeddings + +logger = logging.getLogger("flair") + + +class PrototypicalDecoder(torch.nn.Module): + def __init__( + self, + num_prototypes: int, + embeddings_size: int, + prototype_size: Optional[int] = None, + distance_function: str = "euclidean", + use_radius: 
Optional[bool] = False, + min_radius: Optional[int] = 0, + unlabeled_distance: Optional[float] = None, + unlabeled_idx: Optional[int] = None, + learning_mode: Optional[str] = "joint", + normal_distributed_initial_prototypes: bool = False, + ): + + super().__init__() + + if not prototype_size: + prototype_size = embeddings_size + + self.prototype_size = prototype_size + + # optional metric space decoder if prototypes have different length than embedding + self.metric_space_decoder: Optional[torch.nn.Linear] = None + if prototype_size != embeddings_size: + self.metric_space_decoder = torch.nn.Linear(embeddings_size, prototype_size) + torch.nn.init.xavier_uniform_(self.metric_space_decoder.weight) + + # create initial prototypes for all classes (all initial prototypes are a vector of all 1s) + self.prototype_vectors = torch.nn.Parameter(torch.ones(num_prototypes, prototype_size), requires_grad=True) + + # if set, create initial prototypes from normal distribution + if normal_distributed_initial_prototypes: + self.prototype_vectors = torch.nn.Parameter(torch.normal(torch.zeros(num_prototypes, prototype_size))) + + # if set, use a radius + self.prototype_radii: Optional[torch.nn.Parameter] = None + if use_radius: + self.prototype_radii = torch.nn.Parameter(torch.ones(num_prototypes), requires_grad=True) + + self.min_radius = min_radius + self.learning_mode = learning_mode + + assert (unlabeled_idx is None) == ( + unlabeled_distance is None + ), "'unlabeled_idx' and 'unlabeled_distance' should either both be set or both not be set." + + self.unlabeled_idx = unlabeled_idx + self.unlabeled_distance = unlabeled_distance + + self._distance_function = distance_function + + self.distance: Optional[torch.nn.Module] = None + if distance_function.lower() == "hyperbolic": + self.distance = HyperbolicDistance() + elif distance_function.lower() == "cosine": + self.distance = CosineDistance() + elif distance_function.lower() == "logit_cosine": + self.distance = LogitCosineDistance() + elif distance_function.lower() == "euclidean": + self.distance = EuclideanDistance() + elif distance_function.lower() == "dot_product": + self.distance = NegativeScaledDotProduct() + else: + raise KeyError(f"Distance function {distance_function} not found.") + + # all parameters will be pushed internally to the specified device + self.to(flair.device) + + @property + def num_prototypes(self): + return self.prototype_vectors.size(0) + + def forward(self, embedded): + if self.learning_mode == "learn_only_map_and_prototypes": + embedded = embedded.detach() + + # decode embeddings into prototype space + if self.metric_space_decoder is not None: + encoded = self.metric_space_decoder(embedded) + else: + encoded = embedded + + prot = self.prototype_vectors + radii = self.prototype_radii + + if self.learning_mode == "learn_only_prototypes": + encoded = encoded.detach() + + if self.learning_mode == "learn_only_embeddings_and_map": + prot = prot.detach() + + if radii is not None: + radii = radii.detach() + + distance = self.distance(encoded, prot) + + if radii is not None: + distance /= self.min_radius + torch.nn.functional.softplus(radii) + + # if unlabeled distance is set, mask out loss to unlabeled class prototype + if self.unlabeled_distance: + distance[..., self.unlabeled_idx] = self.unlabeled_distance + + scores = -distance + + return scores + + def enable_expectation_maximization( + self, + data: FlairDataset, + encoder: DefaultClassifier, + exempt_labels: List[str] = [], + mini_batch_size: int = 8, + ): + """Applies monkey-patch 
to the train method (which sets the train flag). + + This allows for computation of average prototypes after a training + sequence.""" + + decoder = self + + unpatched_train = encoder.train + + def patched_train(mode: bool = True): + unpatched_train(mode=mode) + if mode: + logger.info("recalculating prototypes") + with torch.no_grad(): + decoder.calculate_prototypes( + data=data, encoder=encoder, exempt_labels=exempt_labels, mini_batch_size=mini_batch_size + ) + + # Monkey-patching is problematic for mypy (https://github.com/python/mypy/issues/2427) + encoder.train = patched_train # type: ignore + + def calculate_prototypes( + self, + data: FlairDataset, + encoder: DefaultClassifier, + exempt_labels: List[str] = [], + mini_batch_size=32, + ): + """ + Function that calculates a prototype for each class based on the euclidean average embedding over the whole dataset + :param data: dataset for which to calculate prototypes + :param encoder: encoder to use + :param exempt_labels: labels to exclude + :param mini_batch_size: number of sentences to embed at the same time + :return: + """ + + # gradients are not required for prototype computation + with torch.no_grad(): + + dataloader = DataLoader(data, batch_size=mini_batch_size) + + # reset prototypes for all classes + new_prototypes = torch.zeros(self.num_prototypes, self.prototype_size, device=flair.device) + + counter: Counter = Counter() + + for batch in tqdm(dataloader): + + logits, labels = encoder.forward_pass(batch) # type: ignore + + if len(labels) > 0: + # decode embeddings into prototype space + if self.metric_space_decoder is not None: + logits = self.metric_space_decoder(logits) + + for logit, label in zip(logits, labels): + counter.update(label) + + idx = encoder.label_dictionary.get_idx_for_item(label[0]) + + new_prototypes[idx] += logit + + # embeddings need to be removed so that memory doesn't fill up + store_embeddings(batch, storage_mode="none") + + # TODO: changes required + for label, count in counter.most_common(): + average_prototype = new_prototypes[encoder.label_dictionary.get_idx_for_item(label)] / count + new_prototypes[encoder.label_dictionary.get_idx_for_item(label)] = average_prototype + + for label in exempt_labels: + label_idx = encoder.label_dictionary.get_idx_for_item(label) + new_prototypes[label_idx] = self.prototype_vectors[label_idx] + + self.prototype_vectors.data = new_prototypes.to(flair.device) diff --git a/flair/nn/distance/__init__.py b/flair/nn/distance/__init__.py new file mode 100644 index 0000000000..b993ddb61c --- /dev/null +++ b/flair/nn/distance/__init__.py @@ -0,0 +1,13 @@ +from .cosine import CosineDistance, LogitCosineDistance, NegativeScaledDotProduct +from .euclidean import EuclideanDistance, EuclideanMean +from .hyperbolic import HyperbolicDistance, HyperbolicMean + +__all__ = [ + "EuclideanDistance", + "EuclideanMean", + "HyperbolicDistance", + "HyperbolicMean", + "CosineDistance", + "LogitCosineDistance", + "NegativeScaledDotProduct", +] diff --git a/flair/nn/distance/cosine.py b/flair/nn/distance/cosine.py new file mode 100644 index 0000000000..d7da1ccc8c --- /dev/null +++ b/flair/nn/distance/cosine.py @@ -0,0 +1,38 @@ +import torch + +# Source: https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/util.py#L23 + + +def dot_product(a: torch.Tensor, b: torch.Tensor, normalize=False): + """ + Computes dot product for pairs of vectors.
+ :param normalize: Vectors are normalized (leads to cosine similarity) + :return: Matrix with res[i][j] = dot_product(a[i], b[j]) + """ + if len(a.shape) == 1: + a = a.unsqueeze(0) + + if len(b.shape) == 1: + b = b.unsqueeze(0) + + if normalize: + a = torch.nn.functional.normalize(a, p=2, dim=1) + b = torch.nn.functional.normalize(b, p=2, dim=1) + + return torch.mm(a, b.transpose(0, 1)) + + +class CosineDistance(torch.nn.Module): + def forward(self, a, b): + return -dot_product(a, b, normalize=True) + + +class LogitCosineDistance(torch.nn.Module): + def forward(self, a, b): + return torch.logit(0.5 - 0.5 * dot_product(a, b, normalize=True)) + + +class NegativeScaledDotProduct(torch.nn.Module): + def forward(self, a, b): + sqrt_d = torch.sqrt(torch.tensor(a.size(-1))) + return -dot_product(a, b, normalize=False) / sqrt_d diff --git a/flair/nn/distance/euclidean.py b/flair/nn/distance/euclidean.py new file mode 100644 index 0000000000..05d842d1eb --- /dev/null +++ b/flair/nn/distance/euclidean.py @@ -0,0 +1,67 @@ +""" +This module was copied from the following repository: +https://github.com/asappresearch/dynamic-classification + +It contains the code from the paper "Metric Learning for Dynamic Text +Classification". + +https://arxiv.org/abs/1911.01026 + +In case this file is modified, please consider contributing to the original +repository. + +It was published under MIT License: +https://github.com/asappresearch/dynamic-classification/blob/master/LICENSE.md + +Source: https://github.com/asappresearch/dynamic-classification/blob/55beb5a48406c187674bea40487c011e8fa45aab/distance/euclidean.py +""" + + +import torch +import torch.nn as nn +from torch import Tensor + + +class EuclideanDistance(nn.Module): + """Implement a EuclideanDistance object.""" + + def forward(self, mat_1: Tensor, mat_2: Tensor) -> Tensor: # type: ignore + """Returns the squared euclidean distance between each + element in mat_1 and each element in mat_2. + + Parameters + ---------- + mat_1: torch.Tensor + matrix of shape (n_1, n_features) + mat_2: torch.Tensor + matrix of shape (n_2, n_features) + + Returns + ------- + dist: torch.Tensor + distance matrix of shape (n_1, n_2) + + """ + _dist = [torch.sum((mat_1 - mat_2[i]) ** 2, dim=1) for i in range(mat_2.size(0))] + dist = torch.stack(_dist, dim=1) + return dist + + +class EuclideanMean(nn.Module): + """Implement a EuclideanMean object.""" + + def forward(self, data: Tensor) -> Tensor: # type: ignore + """Performs a forward pass through the network. + + Parameters + ---------- + data : torch.Tensor + The input data, as a float tensor + + Returns + ------- + torch.Tensor + The encoded output, as a float tensor + + """ + return data.mean(0) diff --git a/flair/nn/distance/hyperbolic.py b/flair/nn/distance/hyperbolic.py new file mode 100644 index 0000000000..4f6ee351b7 --- /dev/null +++ b/flair/nn/distance/hyperbolic.py @@ -0,0 +1,138 @@ +""" +This module was copied from the following repository: +https://github.com/asappresearch/dynamic-classification + +It contains the code from the paper "Metric Learning for Dynamic Text +Classification". + +https://arxiv.org/abs/1911.01026 + +In case this file is modified, please consider contributing to the original +repository.
+ +It was published under MIT License: +https://github.com/asappresearch/dynamic-classification/blob/master/LICENSE.md + +Source: https://github.com/asappresearch/dynamic-classification/blob/55beb5a48406c187674bea40487c011e8fa45aab/distance/hyperbolic.py +""" + +import torch +import torch.nn as nn +from torch import Tensor + +EPSILON = 1e-5 + + +def arccosh(x): + """Compute the arcosh, numerically stable.""" + x = torch.clamp(x, min=1 + EPSILON) + a = torch.log(x) + b = torch.log1p(torch.sqrt(x * x - 1) / x) + return a + b + + +def mdot(x, y): + """Compute the inner product.""" + m = x.new_ones(1, x.size(1)) + m[0, 0] = -1 + return torch.sum(m * x * y, 1, keepdim=True) + + +def dist(x, y): + """Get the hyperbolic distance between x and y.""" + return arccosh(-mdot(x, y)) + + +def project(x): + """Project onto the hyperboloid embedded in n+1 dimensions.""" + return torch.cat([torch.sqrt(1.0 + torch.sum(x * x, 1, keepdim=True)), x], 1) + + +def log_map(x, y): + """Perform the log step.""" + d = dist(x, y) + return (d / torch.sinh(d)) * (y - torch.cosh(d) * x) + + +def norm(x): + """Compute the norm.""" + n = torch.sqrt(torch.abs(mdot(x, x))) + return n + + +def exp_map(x, y): + """Perform the exp step.""" + n = torch.clamp(norm(y), min=EPSILON) + return torch.cosh(n) * x + (torch.sinh(n) / n) * y + + +def loss(x, y): + """Get the loss for the optimizer.""" + return torch.sum(dist(x, y) ** 2) + + +class HyperbolicDistance(nn.Module): + """Implement a HyperbolicDistance object.""" + + def forward(self, mat_1: Tensor, mat_2: Tensor) -> Tensor: # type: ignore + """Returns the squared hyperbolic distance between each + element in mat_1 and each element in mat_2. + + Parameters + ---------- + mat_1: torch.Tensor + matrix of shape (n_1, n_features) + mat_2: torch.Tensor + matrix of shape (n_2, n_features) + + Returns + ------- + dist: torch.Tensor + distance matrix of shape (n_1, n_2) + + """ + # Get projected 1st dimension + mat_1_x_0 = torch.sqrt(1 + mat_1.pow(2).sum(dim=1, keepdim=True)) + mat_2_x_0 = torch.sqrt(1 + mat_2.pow(2).sum(dim=1, keepdim=True)) + + # Compute bilinear form + left = mat_1_x_0.mm(mat_2_x_0.t()) # n_1 x n_2 + right = mat_1[:, 1:].mm(mat_2[:, 1:].t()) # n_1 x n_2 + + # Arcosh + return arccosh(left - right).pow(2) + + +class HyperbolicMean(nn.Module): + """Compute the mean point in the hyperboloid model.""" + + def forward(self, data: Tensor) -> Tensor: # type: ignore + """Performs a forward pass through the network. + + Parameters + ---------- + data : torch.Tensor + The input data, as a float tensor + + Returns + ------- + torch.Tensor + The encoded output, as a float tensor + + """ + n_iter = 5 if self.training else 100 + + # Project the input data to n+1 dimensions + projected = project(data) + + mean = torch.mean(projected, 0, keepdim=True) + mean = mean / norm(mean) + + r = 1e-2 + for i in range(n_iter): + g = -2 * torch.mean(log_map(mean, projected), 0, keepdim=True) + mean = exp_map(mean, -r * g) + mean = mean / norm(mean) + + # The first dimension is recomputed in the distance module + return mean.squeeze()[1:] diff --git a/flair/nn/model.py b/flair/nn/model.py index d78a9315dc..47ca214ba5 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -62,20 +62,19 @@ def evaluate( """ raise NotImplementedError - @abstractmethod def _get_state_dict(self): - """Returns the state dictionary for this model.
- Implementing this enables the save() and save_checkpoint() - functionality.""" - raise NotImplementedError + """Returns the state dictionary for this model.""" + state_dict = {"state_dict": self.state_dict()} - @staticmethod - @abstractmethod - def _init_model_with_state_dict(state): - """Initialize the model from a state dictionary. - Implementing this enables the load() and load_checkpoint() - functionality.""" - raise NotImplementedError + return state_dict + + @classmethod + def _init_model_with_state_dict(cls, state, **kwargs): + """Initialize the model from a state dictionary.""" + model = cls(**kwargs) + + model.load_state_dict(state["state_dict"]) + return model @staticmethod def _fetch_model(model_name) -> str: @@ -490,29 +489,14 @@ class DefaultClassifier(Classifier[DT], typing.Generic[DT]): forward_pass() method to implement this base class. """ - def forward_pass( - self, - sentences: Union[List[DT], DT], - return_label_candidates: bool = False, - ) -> Union[Tuple[torch.Tensor, List[List[str]]], Tuple[torch.Tensor, List[List[str]], List[DT], List[Label]]]: - """This method does a forward pass through the model given a list of data - points as input. - Returns the tuple (scores, labels) if return_label_candidates = False, - where scores are a tensor of logits produced by the decoder and labels - are the string labels for each data point. Returns the tuple (scores, - labels, data_points, candidate_labels) if return_label_candidates = True, - where data_points are the data points to which labels are added (commonly - either Sentence or Token objects) and candidate_labels are empty Label - objects for each prediction (depending on the task Label, SpanLabel or - RelationLabel).""" - raise NotImplementedError - def __init__( self, label_dictionary: Dictionary, + final_embedding_size: int, multi_label: bool = False, multi_label_threshold: float = 0.5, loss_weights: Dict[str, float] = None, + decoder: Optional[torch.nn.Module] = None, ): super().__init__() @@ -520,6 +504,15 @@ def __init__( # initialize the label dictionary self.label_dictionary: Dictionary = label_dictionary + if decoder is not None: + self.decoder = decoder + self._custom_decoder = True + else: + # initialize the decoder + self.decoder = torch.nn.Linear(final_embedding_size, len(self.label_dictionary)) + torch.nn.init.xavier_uniform_(self.decoder.weight) + self._custom_decoder = False + # set up multi-label logic self.multi_label = multi_label self.multi_label_threshold = multi_label_threshold @@ -542,6 +535,23 @@ def __init__( else: self.loss_function = torch.nn.CrossEntropyLoss(weight=self.loss_weights) + def forward_pass( + self, + sentences: Union[List[DT], DT], + return_label_candidates: bool = False, + ) -> Union[Tuple[torch.Tensor, List[List[str]]], Tuple[torch.Tensor, List[List[str]], List[DT], List[Label]]]: + """This method does a forward pass through the model given a list of data + points as input. + Returns the tuple (embedded_data_points, labels) if return_label_candidates = False, + where embedded_data_points is a tensor of embeddings that the calling method pushes + through the decoder to obtain logits, and labels are the string labels for each data point.
Returns the tuple (embedded_data_points, + labels, data_points, candidate_labels) if return_label_candidates = True, + where data_points are the data points to which labels are added (commonly + either Sentence or Token objects) and candidate_labels are empty Label + objects for each prediction (depending on the task Label, SpanLabel or + RelationLabel).""" + raise NotImplementedError + @property def multi_label_threshold(self): return self._multi_label_threshold @@ -557,14 +567,22 @@ def multi_label_threshold(self, x): # setter method self._multi_label_threshold = {"default": x} def forward_loss(self, sentences: Union[List[DT], DT]) -> Tuple[torch.Tensor, int]: - scores, labels = self.forward_pass(sentences) # type: ignore - return self._calculate_loss(scores, labels) - def _calculate_loss(self, scores, labels) -> Tuple[torch.Tensor, int]: + # make a forward pass to produce embedded data points and labels + embedded_data_points, labels = self.forward_pass(sentences) # type: ignore + # no loss can be calculated if there are no labels if not any(labels): return torch.tensor(0.0, requires_grad=True, device=flair.device), 1 + # push embedded_data_points through decoder to get the scores + scores = self.decoder(embedded_data_points) + + # calculate the loss + return self._calculate_loss(scores, labels) + + def _calculate_loss(self, scores, labels) -> Tuple[torch.Tensor, int]: + if self.multi_label: labels = torch.tensor( [ @@ -661,19 +679,21 @@ def predict( if not batch: continue - scores, gold_labels, data_points, label_candidates = self.forward_pass( # type: ignore + embedded_data_points, gold_labels, data_points, label_candidates = self.forward_pass( # type: ignore batch, return_label_candidates=True ) - # remove previously predicted labels of this type - for sentence in data_points: - sentence.remove_labels(label_name) - - if return_loss: - overall_loss += self._calculate_loss(scores, gold_labels)[0] - label_count += len(label_candidates) - # if anything could possibly be predicted if len(label_candidates) > 0: + scores = self.decoder(embedded_data_points) + + # remove previously predicted labels of this type + for sentence in data_points: + sentence.remove_labels(label_name) + + if return_loss: + overall_loss += self._calculate_loss(scores, gold_labels)[0] + label_count += len(label_candidates) + if self.multi_label: sigmoided = torch.sigmoid(scores) # size: (n_sentences, n_classes) n_labels = sigmoided.size(1) @@ -727,3 +747,18 @@ def __str__(self): + f" (weights): {self.weight_dict}\n" + f" (weight_tensor) {self.loss_weights}\n)" ) + + @classmethod + def _init_model_with_state_dict(cls, state, **kwargs): + if "decoder" not in kwargs and "decoder" in state: + kwargs["decoder"] = state["decoder"] + + return super(Classifier, cls)._init_model_with_state_dict(state, **kwargs) + + def _get_state_dict(self): + state = super()._get_state_dict() + + if self._custom_decoder: + state["decoder"] = self.decoder + + return state
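
Taken together, the changes above centralize the classification head: DefaultClassifier now constructs its own linear decoder from final_embedding_size (or accepts any torch.nn.Module through the new decoder argument), and each subclass's forward_pass returns embedded data points that the base class pushes through that decoder in forward_loss and predict. Below is a minimal usage sketch of plugging the new PrototypicalDecoder into a TextClassifier via this mechanism; the corpus, label type and prototype size are illustrative assumptions, not values taken from this diff.

# Sketch only: wires flair.nn.PrototypicalDecoder into a TextClassifier through the
# `decoder` argument that DefaultClassifier.__init__ accepts after this change.
# The corpus, label_type and prototype_size below are assumptions for illustration.
from flair.datasets import TREC_6
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.nn import PrototypicalDecoder

corpus = TREC_6()
label_type = "question_class"  # assumed label type for this corpus
label_dictionary = corpus.make_label_dictionary(label_type=label_type)

embeddings = TransformerDocumentEmbeddings("distilbert-base-uncased", fine_tune=True)

# one prototype per label; embeddings are first projected into a smaller metric space
decoder = PrototypicalDecoder(
    num_prototypes=len(label_dictionary),
    embeddings_size=embeddings.embedding_length,
    prototype_size=64,
    distance_function="euclidean",
)

classifier = TextClassifier(
    document_embeddings=embeddings,
    label_dictionary=label_dictionary,
    label_type=label_type,
    decoder=decoder,  # forwarded via **classifierargs to DefaultClassifier
)

Because _get_state_dict() stores a custom decoder whenever _custom_decoder is set, such a model should round-trip through save() and load() without any subclass-specific changes.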
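
The distance modules added under flair/nn/distance follow one convention: given a batch of encoded points and the prototype matrix, they return a pairwise (n_1, n_2) matrix in which larger values mean farther apart, which is why PrototypicalDecoder.forward negates the distances to obtain logits. A small shape check, assuming only the modules added in this diff:

# Sketch: pairwise distance matrices as consumed by PrototypicalDecoder.forward,
# which computes scores = -distance(encoded, prototypes).
import torch

from flair.nn.distance import CosineDistance, EuclideanDistance, HyperbolicDistance

encoded = torch.randn(4, 16)     # 4 embedded data points
prototypes = torch.randn(3, 16)  # 3 class prototypes

print(EuclideanDistance()(encoded, prototypes).shape)   # torch.Size([4, 3]), squared distances
print(CosineDistance()(encoded, prototypes).shape)      # torch.Size([4, 3]), negated cosine similarity
print(HyperbolicDistance()(encoded, prototypes).shape)  # torch.Size([4, 3])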