diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py index 5861153550..6b89c23915 100644 --- a/flair/models/sequence_tagger_model.py +++ b/flair/models/sequence_tagger_model.py @@ -186,11 +186,7 @@ def save_checkpoint(self, model_file: Path, optimizer_state: dict, scheduler_sta @classmethod def load_from_file(cls, model_file: Union[str, Path]): - # suppress torch warnings: - # https://docs.python.org/3/library/warnings.html#temporarily-suppressing-warnings - with warnings.catch_warnings(): - warnings.filterwarnings("ignore") - state = torch.load(str(model_file), map_location={'cuda:0': 'cpu'}) + state = SequenceTagger._load_state(model_file) use_dropout = 0.0 if not 'use_dropout' in state.keys() else state['use_dropout'] use_word_dropout = 0.0 if not 'use_word_dropout' in state.keys() else state['use_word_dropout'] @@ -208,53 +204,42 @@ def load_from_file(cls, model_file: Union[str, Path]): word_dropout=use_word_dropout, locked_dropout=use_locked_dropout, ) - model.load_state_dict(state['state_dict']) model.eval() + if torch.cuda.is_available(): model = model.cuda() + return model @classmethod def load_checkpoint(cls, model_file: Union[str, Path]): - # suppress torch warnings: - # https://docs.python.org/3/library/warnings.html#temporarily-suppressing-warnings - with warnings.catch_warnings(): - warnings.filterwarnings("ignore") - state = torch.load(str(model_file), map_location={'cuda:0': 'cpu'}) - - use_dropout = 0.0 if not 'use_dropout' in state.keys() else state['use_dropout'] - use_word_dropout = 0.0 if not 'use_word_dropout' in state.keys() else state['use_word_dropout'] - use_locked_dropout = 0.0 if not 'use_locked_dropout' in state.keys() else state['use_locked_dropout'] + state = SequenceTagger._load_state(model_file) + model = SequenceTagger.load_from_file(model_file) epoch = state['epoch'] if 'epoch' in state else None loss = state['loss'] if 'loss' in state else None optimizer_state_dict = state['optimizer_state_dict'] if 'optimizer_state_dict' in state else None scheduler_state_dict = state['scheduler_state_dict'] if 'scheduler_state_dict' in state else None - model = SequenceTagger( - hidden_size=state['hidden_size'], - embeddings=state['embeddings'], - tag_dictionary=state['tag_dictionary'], - tag_type=state['tag_type'], - use_crf=state['use_crf'], - use_rnn=state['use_rnn'], - rnn_layers=state['rnn_layers'], - dropout=use_dropout, - word_dropout=use_word_dropout, - locked_dropout=use_locked_dropout, - ) - - model.load_state_dict(state['state_dict']) - model.eval() - if torch.cuda.is_available(): - model = model.cuda() - return { 'model': model, 'epoch': epoch, 'loss': loss, 'optimizer_state_dict': optimizer_state_dict, 'scheduler_state_dict': scheduler_state_dict } + @classmethod + def _load_state(cls, model_file): + # ATTENTION: suppressing torch serialization warnings. This needs to be taken out once we sort out recursive + # serialization of torch objects + # https://docs.python.org/3/library/warnings.html#temporarily-suppressing-warnings + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + if torch.cuda.is_available(): + state = torch.load(str(model_file)) + else: + state = torch.load(str(model_file), map_location={'cuda:0': 'cpu'}) + return state + def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> torch.tensor: features, lengths, tags = self.forward(sentences) return self._calculate_loss(features, lengths, tags) diff --git a/flair/models/text_classification_model.py b/flair/models/text_classification_model.py index d984687aba..1112f9163c 100644 --- a/flair/models/text_classification_model.py +++ b/flair/models/text_classification_model.py @@ -97,64 +97,55 @@ def save_checkpoint(self, model_file: Path, optimizer_state: dict, scheduler_sta torch.save(model_state, str(model_file), pickle_protocol=4) @classmethod - def load_from_file(cls, model_file: Path): + def load_from_file(cls, model_file: [str, Path]): """ Loads the model from the given file. :param model_file: the model file :return: the loaded text classifier model """ - - # ATTENTION: suppressing torch serialization warnings. This needs to be taken out once we sort out recursive - # serialization of torch objects - # https://docs.python.org/3/library/warnings.html#temporarily-suppressing-warnings - with warnings.catch_warnings(): - warnings.filterwarnings("ignore") - if torch.cuda.is_available(): - state = torch.load(str(model_file)) - else: - state = torch.load(str(model_file), map_location={'cuda:0': 'cpu'}) + state = TextClassifier._load_state(model_file) model = TextClassifier( document_embeddings=state['document_embeddings'], label_dictionary=state['label_dictionary'], multi_label=state['multi_label'] ) - model.load_state_dict(state['state_dict']) model.eval() + + if torch.cuda.is_available(): + model = model.cuda() + return model @classmethod - def load_checkpoint(cls, model_file: Path): - # ATTENTION: suppressing torch serialization warnings. This needs to be taken out once we sort out recursive - # serialization of torch objects - # https://docs.python.org/3/library/warnings.html#temporarily-suppressing-warnings - with warnings.catch_warnings(): - warnings.filterwarnings("ignore") - if torch.cuda.is_available(): - state = torch.load(str(model_file)) - else: - state = torch.load(str(model_file), map_location={'cuda:0': 'cpu'}) + def load_checkpoint(cls, model_file: [str, Path]): + state = TextClassifier._load_state(model_file) + model = TextClassifier.load_from_file(model_file) epoch = state['epoch'] if 'epoch' in state else None loss = state['loss'] if 'loss' in state else None optimizer_state_dict = state['optimizer_state_dict'] if 'optimizer_state_dict' in state else None scheduler_state_dict = state['scheduler_state_dict'] if 'scheduler_state_dict' in state else None - model = TextClassifier( - document_embeddings=state['document_embeddings'], - label_dictionary=state['label_dictionary'], - multi_label=state['multi_label'] - ) - - model.load_state_dict(state['state_dict']) - model.eval() - return { 'model': model, 'epoch': epoch, 'loss': loss, 'optimizer_state_dict': optimizer_state_dict, 'scheduler_state_dict': scheduler_state_dict } + @classmethod + def _load_state(cls, model_file): + # ATTENTION: suppressing torch serialization warnings. This needs to be taken out once we sort out recursive + # serialization of torch objects + # https://docs.python.org/3/library/warnings.html#temporarily-suppressing-warnings + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + if torch.cuda.is_available(): + state = torch.load(str(model_file)) + else: + state = torch.load(str(model_file), map_location={'cuda:0': 'cpu'}) + return state + def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> torch.tensor: scores = self.forward(sentences) return self._calculate_loss(scores, sentences) diff --git a/flair/trainers/language_model_trainer.py b/flair/trainers/language_model_trainer.py index 6e577b1592..01127bcafd 100644 --- a/flair/trainers/language_model_trainer.py +++ b/flair/trainers/language_model_trainer.py @@ -160,7 +160,7 @@ def __init__(self, optimizer: Optimizer = SGD, test_mode: bool = False, epoch: int = 0, - loss: float = 1, + loss: float = 10000, optimizer_state: dict = None ): self.model: LanguageModel = model diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index 13dc719551..edf62226ba 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -23,7 +23,7 @@ def __init__(self, corpus: Corpus, optimizer: Optimizer = SGD, epoch:int = 0, - loss: float = 1.0, + loss: float = 10000.0, optimizer_state: dict = None, scheduler_state: dict = None ): @@ -46,7 +46,7 @@ def find_learning_rate(self, **kwargs ) -> Path: loss_history = [] - best_loss = 0 + best_loss = None learning_rate_tsv = init_output_file(base_path, 'learning_rate.tsv') with open(learning_rate_tsv, 'a') as f: @@ -159,8 +159,8 @@ def train(self, if train_with_dev: train_data.extend(self.corpus.dev) - current_loss = 0 - current_score = 0 + current_loss = 0.0 + current_score = 0.0 # At any point you can hit Ctrl + C to break out of training early. try: @@ -475,9 +475,6 @@ def _evaluate_text_classifier(model: flair.nn.Model, sentences: List[Sentence], @staticmethod def load_from_checkpoint(checkpoint_file: Path, model_type: str, corpus: Corpus, optimizer: Optimizer = SGD): - if model_type not in ['SequenceTagger', 'TextClassifier']: - raise ValueError('Incorrect model type! Use one of the following: "SequenceTagger", "TextClassifier".') - if model_type == 'SequenceTagger': checkpoint = SequenceTagger.load_checkpoint(checkpoint_file) return ModelTrainer(checkpoint['model'], corpus, optimizer, epoch=checkpoint['epoch'], @@ -489,3 +486,5 @@ def load_from_checkpoint(checkpoint_file: Path, model_type: str, corpus: Corpus, return ModelTrainer(checkpoint['model'], corpus, optimizer, epoch=checkpoint['epoch'], loss=checkpoint['loss'], optimizer_state=checkpoint['optimizer_state_dict'], scheduler_state=checkpoint['scheduler_state_dict']) + + raise ValueError('Incorrect model type! Use one of the following: "SequenceTagger", "TextClassifier".') diff --git a/tests/test_model_integration.py b/tests/test_model_integration.py index b70ede8ee3..3be3354b81 100644 --- a/tests/test_model_integration.py +++ b/tests/test_model_integration.py @@ -408,9 +408,6 @@ def test_train_language_model(results_base_path, resources_path): assert (text is not None) assert (len(text) == 100) - loaded_language_model = LanguageModel.load_language_model(results_base_path / 'best-lm.pt') - assert (loaded_language_model.best_score < 100) - # clean up results directory shutil.rmtree(results_base_path, ignore_errors=True)