Skip to content

Commit

Permalink
GH-217: Refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
tabergma committed Nov 28, 2018
1 parent faaebfc commit b3df780
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 75 deletions.
51 changes: 18 additions & 33 deletions flair/models/sequence_tagger_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,7 @@ def save_checkpoint(self, model_file: Path, optimizer_state: dict, scheduler_sta

@classmethod
def load_from_file(cls, model_file: Union[str, Path]):
# suppress torch warnings:
# https://docs.python.org/3/library/warnings.html#temporarily-suppressing-warnings
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
state = torch.load(str(model_file), map_location={'cuda:0': 'cpu'})
state = SequenceTagger._load_state(model_file)

use_dropout = 0.0 if not 'use_dropout' in state.keys() else state['use_dropout']
use_word_dropout = 0.0 if not 'use_word_dropout' in state.keys() else state['use_word_dropout']
Expand All @@ -208,53 +204,42 @@ def load_from_file(cls, model_file: Union[str, Path]):
word_dropout=use_word_dropout,
locked_dropout=use_locked_dropout,
)

model.load_state_dict(state['state_dict'])
model.eval()

if torch.cuda.is_available():
model = model.cuda()

return model

@classmethod
def load_checkpoint(cls, model_file: Union[str, Path]):
# suppress torch warnings:
# https://docs.python.org/3/library/warnings.html#temporarily-suppressing-warnings
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
state = torch.load(str(model_file), map_location={'cuda:0': 'cpu'})

use_dropout = 0.0 if not 'use_dropout' in state.keys() else state['use_dropout']
use_word_dropout = 0.0 if not 'use_word_dropout' in state.keys() else state['use_word_dropout']
use_locked_dropout = 0.0 if not 'use_locked_dropout' in state.keys() else state['use_locked_dropout']
state = SequenceTagger._load_state(model_file)
model = SequenceTagger.load_from_file(model_file)

epoch = state['epoch'] if 'epoch' in state else None
loss = state['loss'] if 'loss' in state else None
optimizer_state_dict = state['optimizer_state_dict'] if 'optimizer_state_dict' in state else None
scheduler_state_dict = state['scheduler_state_dict'] if 'scheduler_state_dict' in state else None

model = SequenceTagger(
hidden_size=state['hidden_size'],
embeddings=state['embeddings'],
tag_dictionary=state['tag_dictionary'],
tag_type=state['tag_type'],
use_crf=state['use_crf'],
use_rnn=state['use_rnn'],
rnn_layers=state['rnn_layers'],
dropout=use_dropout,
word_dropout=use_word_dropout,
locked_dropout=use_locked_dropout,
)

model.load_state_dict(state['state_dict'])
model.eval()
if torch.cuda.is_available():
model = model.cuda()

return {
'model': model, 'epoch': epoch, 'loss': loss,
'optimizer_state_dict': optimizer_state_dict, 'scheduler_state_dict': scheduler_state_dict
}

@classmethod
def _load_state(cls, model_file):
# ATTENTION: suppressing torch serialization warnings. This needs to be taken out once we sort out recursive
# serialization of torch objects
# https://docs.python.org/3/library/warnings.html#temporarily-suppressing-warnings
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
if torch.cuda.is_available():
state = torch.load(str(model_file))
else:
state = torch.load(str(model_file), map_location={'cuda:0': 'cpu'})
return state

def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> torch.tensor:
features, lengths, tags = self.forward(sentences)
return self._calculate_loss(features, lengths, tags)
Expand Down
53 changes: 22 additions & 31 deletions flair/models/text_classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,64 +97,55 @@ def save_checkpoint(self, model_file: Path, optimizer_state: dict, scheduler_sta
torch.save(model_state, str(model_file), pickle_protocol=4)

@classmethod
def load_from_file(cls, model_file: Path):
def load_from_file(cls, model_file: [str, Path]):
"""
Loads the model from the given file.
:param model_file: the model file
:return: the loaded text classifier model
"""

# ATTENTION: suppressing torch serialization warnings. This needs to be taken out once we sort out recursive
# serialization of torch objects
# https://docs.python.org/3/library/warnings.html#temporarily-suppressing-warnings
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
if torch.cuda.is_available():
state = torch.load(str(model_file))
else:
state = torch.load(str(model_file), map_location={'cuda:0': 'cpu'})
state = TextClassifier._load_state(model_file)

model = TextClassifier(
document_embeddings=state['document_embeddings'],
label_dictionary=state['label_dictionary'],
multi_label=state['multi_label']
)

model.load_state_dict(state['state_dict'])
model.eval()

if torch.cuda.is_available():
model = model.cuda()

return model

@classmethod
def load_checkpoint(cls, model_file: Path):
# ATTENTION: suppressing torch serialization warnings. This needs to be taken out once we sort out recursive
# serialization of torch objects
# https://docs.python.org/3/library/warnings.html#temporarily-suppressing-warnings
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
if torch.cuda.is_available():
state = torch.load(str(model_file))
else:
state = torch.load(str(model_file), map_location={'cuda:0': 'cpu'})
def load_checkpoint(cls, model_file: [str, Path]):
state = TextClassifier._load_state(model_file)
model = TextClassifier.load_from_file(model_file)

epoch = state['epoch'] if 'epoch' in state else None
loss = state['loss'] if 'loss' in state else None
optimizer_state_dict = state['optimizer_state_dict'] if 'optimizer_state_dict' in state else None
scheduler_state_dict = state['scheduler_state_dict'] if 'scheduler_state_dict' in state else None

model = TextClassifier(
document_embeddings=state['document_embeddings'],
label_dictionary=state['label_dictionary'],
multi_label=state['multi_label']
)

model.load_state_dict(state['state_dict'])
model.eval()

return {
'model': model, 'epoch': epoch, 'loss': loss,
'optimizer_state_dict': optimizer_state_dict, 'scheduler_state_dict': scheduler_state_dict
}

@classmethod
def _load_state(cls, model_file):
# ATTENTION: suppressing torch serialization warnings. This needs to be taken out once we sort out recursive
# serialization of torch objects
# https://docs.python.org/3/library/warnings.html#temporarily-suppressing-warnings
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
if torch.cuda.is_available():
state = torch.load(str(model_file))
else:
state = torch.load(str(model_file), map_location={'cuda:0': 'cpu'})
return state

def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> torch.tensor:
scores = self.forward(sentences)
return self._calculate_loss(scores, sentences)
Expand Down
2 changes: 1 addition & 1 deletion flair/trainers/language_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __init__(self,
optimizer: Optimizer = SGD,
test_mode: bool = False,
epoch: int = 0,
loss: float = 1,
loss: float = 10000,
optimizer_state: dict = None
):
self.model: LanguageModel = model
Expand Down
13 changes: 6 additions & 7 deletions flair/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self,
corpus: Corpus,
optimizer: Optimizer = SGD,
epoch:int = 0,
loss: float = 1.0,
loss: float = 10000.0,
optimizer_state: dict = None,
scheduler_state: dict = None
):
Expand All @@ -46,7 +46,7 @@ def find_learning_rate(self,
**kwargs
) -> Path:
loss_history = []
best_loss = 0
best_loss = None

learning_rate_tsv = init_output_file(base_path, 'learning_rate.tsv')
with open(learning_rate_tsv, 'a') as f:
Expand Down Expand Up @@ -159,8 +159,8 @@ def train(self,
if train_with_dev:
train_data.extend(self.corpus.dev)

current_loss = 0
current_score = 0
current_loss = 0.0
current_score = 0.0

# At any point you can hit Ctrl + C to break out of training early.
try:
Expand Down Expand Up @@ -475,9 +475,6 @@ def _evaluate_text_classifier(model: flair.nn.Model, sentences: List[Sentence],

@staticmethod
def load_from_checkpoint(checkpoint_file: Path, model_type: str, corpus: Corpus, optimizer: Optimizer = SGD):
if model_type not in ['SequenceTagger', 'TextClassifier']:
raise ValueError('Incorrect model type! Use one of the following: "SequenceTagger", "TextClassifier".')

if model_type == 'SequenceTagger':
checkpoint = SequenceTagger.load_checkpoint(checkpoint_file)
return ModelTrainer(checkpoint['model'], corpus, optimizer, epoch=checkpoint['epoch'],
Expand All @@ -489,3 +486,5 @@ def load_from_checkpoint(checkpoint_file: Path, model_type: str, corpus: Corpus,
return ModelTrainer(checkpoint['model'], corpus, optimizer, epoch=checkpoint['epoch'],
loss=checkpoint['loss'], optimizer_state=checkpoint['optimizer_state_dict'],
scheduler_state=checkpoint['scheduler_state_dict'])

raise ValueError('Incorrect model type! Use one of the following: "SequenceTagger", "TextClassifier".')
3 changes: 0 additions & 3 deletions tests/test_model_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,9 +408,6 @@ def test_train_language_model(results_base_path, resources_path):
assert (text is not None)
assert (len(text) == 100)

loaded_language_model = LanguageModel.load_language_model(results_base_path / 'best-lm.pt')
assert (loaded_language_model.best_score < 100)

# clean up results directory
shutil.rmtree(results_base_path, ignore_errors=True)

Expand Down

0 comments on commit b3df780

Please sign in to comment.